Fix attention parity for GPT-2 (#8549)

* Use persistent softmax to parity with huggingface
* fix undirectional mask logic
* add test
This commit is contained in:
Tianlei Wu 2021-07-30 16:49:20 -07:00 committed by GitHub
parent 816ad86d14
commit 330b8e74bd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 680 additions and 74 deletions

View file

@ -112,11 +112,12 @@ class AttentionCPUBase : public AttentionBase {
{
if (mask_data != nullptr) {
PrepareMask(mask_index, mask_index_dims, mask_data, is_unidirectional_, batch_size, sequence_length, past_sequence_length);
} else { // no any mask
memset(attention_probs, 0, static_cast<size_t>(batch_size) * num_heads_ * sequence_length * all_sequence_length * sizeof(T));
// Convert attention mask data from int to float (0 to -10000.0f). The mask_data shape is BxSxS*.
PrepareMask(mask_index, mask_index_dims, mask_data, batch_size, sequence_length, past_sequence_length);
}
memset(attention_probs, 0, static_cast<size_t>(batch_size) * num_heads_ * sequence_length * all_sequence_length * sizeof(T));
const int loop_len = batch_size * num_heads_;
const float alpha = 1.0f / sqrt(static_cast<float>(head_size));
@ -127,32 +128,45 @@ class AttentionCPUBase : public AttentionBase {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const std::ptrdiff_t batch_index = i / num_heads_;
// broadcast mask data: (Bx)SxS* -> (BxNx)SxS*
if (mask_data != nullptr) {
const T* broadcast_data_src = reinterpret_cast<T*>(mask_data) + batch_index * sequence_length * all_sequence_length;
T* broadcast_data_dest = reinterpret_cast<T*>(attention_probs) + sequence_length * all_sequence_length * i;
memcpy(broadcast_data_dest, broadcast_data_src, sequence_length * all_sequence_length * sizeof(T));
}
const T* k = K + input_chunk_length * i;
if (nullptr != present) {
// concatenate past_K and K : (BxNx)S'xH, (BxNx)SxH -> (BxNx)S*xH
// Concatenate past_K and K : (BxNx)S'xH, (BxNx)SxH -> (BxNx)S*xH
k = ConcatStateChunk(past, k, present, past_chunk_length, present_chunk_length, i);
}
// gemm
int offset = sequence_length * all_sequence_length * static_cast<int>(i);
T* output = reinterpret_cast<T*>(attention_probs) + offset;
// Compute Q*K'
// original transposed each iteration
// A: Q (B x N x) S x H (B x N x) S x H S x H
// B: K' (B x N x) S* x H (B x N x) H x S* H x S*
// C: attention_probs (B x N x) S x S* (B x N x) S x S* S x S*
math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, all_sequence_length, head_size, alpha,
Q + input_chunk_length * i, k, 1.0,
reinterpret_cast<T*>(attention_probs) + sequence_length * all_sequence_length * i, nullptr);
output, nullptr);
// Apply unidirectional mask and set future words to -10000.0f.
if (is_unidirectional_) {
for (int s = 0; s < sequence_length - 1; s++) {
for (int t = past_sequence_length + s + 1; t < all_sequence_length; t++) {
output[s * all_sequence_length + t] = static_cast<T>(-10000.0f);
}
}
}
// Apply attention mask
if (mask_data != nullptr) {
const T* attention_mask = reinterpret_cast<T*>(mask_data) + batch_index * sequence_length * all_sequence_length;
for (int j = 0; j < sequence_length * all_sequence_length ; j++) {
output[j] += attention_mask[j];
}
}
if (extra_add_qk_data != nullptr) {
int extra_add_qk_offset = static_cast<int>(i) * sequence_length * all_sequence_length;
const T* extra = extra_add_qk_data + offset;
for (int j = 0; j < sequence_length * all_sequence_length ; j++) {
attention_probs[extra_add_qk_offset+j] += extra_add_qk_data[extra_add_qk_offset + j];
output[j] += extra[j];
}
}
}

View file

@ -64,7 +64,6 @@ template <typename T>
void PrepareMask(const int32_t* mask_index,
const std::vector<int64_t>* mask_index_dims,
T* mask_data,
bool is_unidirectional,
int batch_size,
int sequence_length,
int past_sequence_length) {
@ -84,18 +83,6 @@ void PrepareMask(const int32_t* mask_index,
for (int i = 0; i < batch_size * sequence_length * all_sequence_length; i++) {
p_mask[i] = (mask_index[i] > 0) ? static_cast<T>(0.0f) : static_cast<T>(-10000.0f);
}
if (is_unidirectional) {
for (int b_i = 0; b_i < batch_size; b_i++) {
for (int s_i = 0; s_i < sequence_length - 1; s_i++) {
for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) {
p_mask[s_i * all_sequence_length + m_i] += static_cast<T>(-10000.0f);
}
}
p_mask += sequence_length * all_sequence_length;
}
}
return;
}
@ -135,15 +122,6 @@ void PrepareMask(const int32_t* mask_index,
memcpy(p_mask + s_i * all_sequence_length, p_mask, all_sequence_length * sizeof(T));
}
// Apply unidirectional mask.
if (is_unidirectional) {
for (int s_i = 0; s_i < sequence_length - 1; s_i++) {
for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) {
p_mask[s_i * all_sequence_length + m_i] += static_cast<T>(-10000.0f);
}
}
}
p_mask += sequence_length * all_sequence_length;
}
}

View file

@ -27,6 +27,7 @@ limitations under the License.
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "attention_impl.h"
#include "attention_softmax.h"
#include "transformer_common.h"
using namespace onnxruntime::cuda;
using namespace cub;
@ -65,7 +66,7 @@ bool QkvToContext(
const int batch_size, const int sequence_length, const int num_heads, const int head_size, const size_t element_size,
const T* input, T* output, T* workspace,
const int* mask_index, const std::vector<int64_t>* mask_index_dims,
bool is_unidirectional, int past_sequence_length, const T* past, T* present) {
bool is_unidirectional, int past_sequence_length, const T* past, T* present, bool use_persistent_softmax) {
const int all_sequence_length = past_sequence_length + sequence_length;
const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length);
T* scratch1 = workspace;
@ -127,8 +128,11 @@ bool QkvToContext(
if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask
const int mask_dimension = static_cast<int>(mask_index_dims->size());
const int64_t max_sequence_length = mask_dimension == 4 ? mask_index_dims->at(3) : 0;
T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score if persistent softmax is selected.
if (!ComputeSoftmaxWithRawMask<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2, is_unidirectional,
rsqrt_head_size, mask_dimension, static_cast<int>(max_sequence_length))) {
rsqrt_head_size, mask_dimension, static_cast<int>(max_sequence_length),
use_persistent_softmax, persistent_softmax_workspace)) {
return false;
}
} else if (nullptr != mask_index) { // 1d mask index
@ -173,18 +177,27 @@ bool LaunchAttentionKernel(
int past_sequence_length,
const void* past,
void* present) {
// GPT-2 model is more sensitive on parity since error will accumulate in text generation.
// So use persistent softmax for GPT-2 model by default.
// For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 or 2 could enable or disable it explicitly.
const TransformerOptions* options = TransformerOptions::GetInstance();
bool use_persistent_softmax = (is_unidirectional || options->IsPrecisionMode()) && !options->DisablePersistentSoftmax();
if (element_size == 2) {
return QkvToContext(prop, cublas, stream,
batch_size, sequence_length, num_heads, head_size, element_size,
reinterpret_cast<const half*>(input), reinterpret_cast<half*>(output), reinterpret_cast<half*>(workspace),
mask_index, mask_index_dims, is_unidirectional,
past_sequence_length, reinterpret_cast<const half*>(past), reinterpret_cast<half*>(present));
past_sequence_length, reinterpret_cast<const half*>(past), reinterpret_cast<half*>(present),
use_persistent_softmax);
} else {
return QkvToContext(prop, cublas, stream,
batch_size, sequence_length, num_heads, head_size, element_size,
reinterpret_cast<const float*>(input), reinterpret_cast<float*>(output), reinterpret_cast<float*>(workspace),
mask_index, mask_index_dims, is_unidirectional,
past_sequence_length, reinterpret_cast<const float*>(past), reinterpret_cast<float*>(present));
past_sequence_length, reinterpret_cast<const float*>(past), reinterpret_cast<float*>(present),
use_persistent_softmax);
}
}

View file

@ -19,11 +19,13 @@ limitations under the License.
#pragma once
#include <type_traits>
#include <cub/cub.cuh>
#include <cuda_fp16.h>
#include <math_constants.h>
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/math/softmax.h"
using namespace onnxruntime::cuda;
using namespace cub;
@ -165,9 +167,10 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
const T* input,
T* output,
const bool is_unidirectional,
const float scalar,
const float rsqrt_head_size,
const int mask_dimension,
const int max_sequence_length) {
const int max_sequence_length,
const bool skip_softmax) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmp_storage;
@ -179,30 +182,36 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
float thread_data = -CUDART_INF_F;
if (threadIdx.x < all_sequence_length) {
const int batch_index = blockIdx.y;
thread_data = float(input[index]) * rsqrt_head_size;
const int sequence_index = blockIdx.x % sequence_length;
if (is_unidirectional) {
int from_index = all_sequence_length - sequence_length + sequence_index; // offset of from token in all sequence length.
if (threadIdx.x > from_index) {
thread_data = -10000.0f;
}
}
int mask_offset = 0;
const int batch_index = blockIdx.y;
if (mask_dimension == 2) {
mask_offset = batch_index * all_sequence_length + threadIdx.x;
} else if (mask_dimension == 3) {
mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + threadIdx.x;
} else if (mask_dimension == 4){
// Megatron code:
// ltor_mask = ltor_mask[..., (attention_scores.size(3)-hidden_states.size(1)):attention_scores.size(3), :attention_scores.size(3)]
} else if (mask_dimension == 4) {
mask_offset = (batch_index * max_sequence_length + all_sequence_length - sequence_length + sequence_index) * max_sequence_length + threadIdx.x;
}
const int& mask = attention_mask[mask_offset];
float mask_value = mask > 0 ? 0.0f : -10000.0f;
if (mask == 0)
thread_data += -10000.0f;
}
if (is_unidirectional) {
int from_index = all_sequence_length - sequence_length + sequence_index; // offset of from token in all sequence length.
if (threadIdx.x > from_index) {
mask_value += -10000.0f;
}
if (skip_softmax) {
if (threadIdx.x < all_sequence_length) {
output[index] = T(thread_data);
}
thread_data = float(input[index]) * scalar + mask_value;
return;
}
const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), all_sequence_length);
@ -264,7 +273,7 @@ bool ComputeSoftmax(
const int blockSize = 1024;
SoftmaxKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output);
} else {
ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024.");
ORT_THROW("Attention CUDA operator does not support total sequence length > 1024.");
}
return CUDA_CALL(cudaPeekAtLastError());
@ -314,8 +323,9 @@ __global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int seq
template <typename T, unsigned TPB>
__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, const int sequence_length, const int* attention_mask, const T* input, T* output,
const bool is_unidirectional, const float scalar, const int mask_dimension, const int max_sequence_length) {
SoftmaxWithRawMaskSmall<T, TPB>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
const bool is_unidirectional, const float rsqrt_head_size, const int mask_dimension, const int max_sequence_length,
const bool skip_softmax) {
SoftmaxWithRawMaskSmall<T, TPB>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, skip_softmax);
}
template <typename T>
@ -352,7 +362,7 @@ bool ComputeSoftmaxWithMask1D(cudaStream_t stream, const int all_sequence_length
MaskedSoftmaxKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output);
} else {
ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024.");
ORT_THROW("Attention CUDA operator does not support total sequence length > 1024.");
}
return CUDA_CALL(cudaPeekAtLastError());
@ -360,36 +370,42 @@ bool ComputeSoftmaxWithMask1D(cudaStream_t stream, const int all_sequence_length
template <typename T>
bool ComputeSoftmaxWithRawMask(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar,
const int mask_dimension, const int max_sequence_length) {
const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float rsqrt_head_size,
const int mask_dimension, const int max_sequence_length,
const bool use_persistent_softmax, T* persistent_softmax_workspace) {
const dim3 grid(sequence_length * num_heads, batch_size, 1);
T* out = use_persistent_softmax ? persistent_softmax_workspace : output;
if (all_sequence_length <= 32) {
const int blockSize = 32;
SoftmaxWithRawMaskSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
} else if (all_sequence_length <= 64) {
const int blockSize = 64;
SoftmaxWithRawMaskSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
} else if (all_sequence_length <= 128) {
const int blockSize = 128;
SoftmaxWithRawMaskSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
} else if (all_sequence_length <= 256) {
const int blockSize = 256;
SoftmaxWithRawMaskSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
} else if (all_sequence_length <= 512) {
const int blockSize = 512;
SoftmaxWithRawMaskSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
} else if (all_sequence_length <= 1024) {
const int blockSize = 1024;
SoftmaxWithRawMaskSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
} else {
ORT_THROW("Attention CUDA operator does not supported 2D attention mask with total sequence length > 1024.");
ORT_THROW("Attention CUDA operator does not support total sequence length > 1024.");
}
if (use_persistent_softmax) {
dispatch_softmax_forward<T, T, float, false>(stream, output, persistent_softmax_workspace, all_sequence_length, all_sequence_length, batch_size * num_heads * sequence_length);
}
return CUDA_CALL(cudaPeekAtLastError());

View file

@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <iostream>
#include "core/providers/shared_library/provider_api.h" // Include this otherwise Windows build complains Env::Default() missing
#include "core/platform/env_var_utils.h"
#include "transformer_common.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
// The environment variable is for testing purpose only, and it might be removed in the future.
// If you need some option in production, please file a feature request.
constexpr const char* kTransformerOptions = "ORT_TRANSFORMER_OPTIONS";
// Initialize the singleton instance
TransformerOptions TransformerOptions::instance;
const TransformerOptions* TransformerOptions::GetInstance() {
if (!instance.initialized_) {
// We do not use critical section here since it is fine to initialize multiple times by different threads.
int value = ParseEnvironmentVariableWithDefault<int>(kTransformerOptions, 0);
instance.Initialize(value);
if (value > 0)
std::cout << "ORT_TRANSFORMER_OPTIONS: IsPrecisionMode=" << instance.IsPrecisionMode()
<< ",DisablePersistentSoftmax=" << instance.DisablePersistentSoftmax() << std::endl;
}
return &instance;
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
namespace contrib {
namespace cuda {
class TransformerOptions {
public:
static const TransformerOptions* GetInstance();
bool IsPrecisionMode() const { return is_precision_mode_; }
bool DisablePersistentSoftmax() const { return disable_persistent_softmax_; }
void Initialize(int value) {
is_precision_mode_ = (value & 0x01) > 0;
disable_persistent_softmax_ = (value & 0x02) > 0;
initialized_ = true;
}
private:
// Default is false. If the mode is on, prefer precision than speed.
bool is_precision_mode_{false};
// Disable persistent softmax.
bool disable_persistent_softmax_{false};
bool initialized_{false};
static TransformerOptions instance;
};
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -39,6 +39,7 @@ class FusionSkipLayerNormalization(Fusion):
if self.shape_infer_helper is not None:
if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]):
logger.debug(f"skip skiplayernorm fusion since shape of inputs ({add.input[0]}, {add.input[1]}) are not same")
return
else:
# shape_infer_helper can not handle subgraphs. Current work around is to disable skiplayernorm fusion

View file

@ -799,6 +799,7 @@ class OnnxModel:
else:
deps_to_nodes[input_name].append(node_idx)
# Note: this logic only applies to top level graph since a sub graph could use intializer from parent graph
initializer_names = [init.name for init in graph.initializer]
graph_input_names = [input.name for input in graph.input]
input_names = initializer_names + graph_input_names

View file

@ -185,9 +185,6 @@ class BertOnnxModel(OnnxModel):
return
def adjust_reshape_and_expand(self):
# Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
FusionUtils.remove_useless_reshape_nodes(self)
nodes_to_remove = []
for node in self.nodes():
if node.op_type == 'Reshape':
@ -289,7 +286,9 @@ class BertOnnxModel(OnnxModel):
if (options is None) or options.enable_embed_layer_norm:
self.fuse_embed_layer()
# Post-processing like removing extra reshape nodes.
# Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
FusionUtils.remove_useless_reshape_nodes(self)
self.postprocess()
# Bias fusion is done after postprocess to avoid extra Reshape between bias and Gelu/FastGelu/SkipLayerNormalization

View file

@ -1291,7 +1291,7 @@ TEST(AttentionTest, AttentionUnidirectional3DMask) {
1, 1};
std::vector<float> output_data = {
3.967245340f, 0.07324841f, 4.25f, 5.65f,
3.0146f, 0.1142f, 3.9834f, 5.3394f,
8.69f, -0.13f, 4.25f, 5.65f,
8.69f, -0.13f, 4.25f, 5.65f,
3.96967912f, 0.07314367f, 4.25f, 5.65f};
@ -1332,7 +1332,7 @@ TEST(AttentionTest, AttentionUnidirectionalAttentionMask) {
std::vector<int32_t> mask_index_data = {0, 1, 1, 1};
std::vector<float> output_data = {
3.967245340f, 0.07324841f, 4.25f, 5.65f,
3.0146f, 0.1142f, 3.9834f, 5.3394f,
8.69f, -0.13f, 4.25f, 5.65f,
8.69f, -0.13f, 4.25f, 5.65f,
3.96967912f, 0.07314367f, 4.25f, 5.65f};

View file

@ -0,0 +1,511 @@
# --------------------------------------------------------------------------
# Copyright 2020 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import unittest
import pytest
import torch
from torch import nn
import random
from onnx import helper
import onnx
import numpy
import os
from transformers.modeling_utils import Conv1D
DEBUG_OUTPUTS = ["qk", "norm_qk", "softmax", "attn_weights"]
class MyGPT2Attention(nn.Module):
"""
This module is modifed from Gpt2Attention of huggingface transformers v4.9.1.
Code related to crosss attention, c_proj, attn_dropout and head_mask etc are removed.
"""
def __init__(self,
max_position_embeddings=1024,
hidden_size=768,
num_attention_heads=12,
use_cache=True,
debug=False,
fix_onnx_export=True):
super().__init__()
max_positions = max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions),
dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
)
self.register_buffer("masked_bias", torch.tensor(-1e4))
self.embed_dim = hidden_size
self.num_heads = num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
assert self.head_dim * self.num_heads == self.embed_dim
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
# Use random bias instead of zeros for parity test.
self.c_attn.bias = nn.Parameter(torch.normal(0.0, 0.1, (3 * self.embed_dim, )))
self.use_cache = use_cache
self.debug = debug
self.fix_onnx_export = fix_onnx_export
def _attn(self, query, key, value, attention_mask=None):
qk = torch.matmul(query, key.transpose(-1, -2))
# Torch has special handling for Div and Mul by a scalar:
# https://github.com/pytorch/pytorch/blob/5536cda19a5def9e0553b318f04d297d602ac956/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu#L52-L60
# https://github.com/pytorch/pytorch/blob/5536cda19a5def9e0553b318f04d297d602ac956/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu#L185-L194
# Modify the code to use same processing in onnx export so as to get parity result without attention fusion.
# This walkaround is not needed when attention fusion will be applied later since the subgraph will be replaced by an Attention node.
if self.fix_onnx_export and torch.onnx.is_in_onnx_export():
if qk.dtype == torch.float16:
norm_qk = qk.to(torch.float32) * (1.0 / (float(value.size(-1))**0.5))
norm_qk = norm_qk.to(torch.float16)
else:
norm_qk = qk * (1.0 / (float(value.size(-1))**0.5))
else:
norm_qk = qk / (float(value.size(-1))**0.5)
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool()
attn_weights = torch.where(causal_mask, norm_qk, self.masked_bias.to(norm_qk.dtype))
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
softmax = nn.Softmax(dim=-1)(attn_weights)
attn_output = torch.matmul(softmax, value)
if self.debug:
return attn_output, qk, norm_qk, softmax, attn_weights
else:
return attn_output
def _split_heads(self, tensor, num_heads, attn_head_size):
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(*new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
def _merge_heads(self, tensor, num_heads, attn_head_size):
tensor = tensor.permute(0, 2, 1, 3).contiguous()
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size, )
return tensor.view(new_shape)
@staticmethod
def concat_key_value(key, value):
return torch.cat((key.unsqueeze(0), value.unsqueeze(0)), dim=0)
@staticmethod
def process_mask(attention_mask, dtype):
# Create a 4D attention mask with shape [batch_size, 1, 1, to_seq_length] from a 2D tensor mask.
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
return attention_mask
def forward(self, hidden_states, attention_mask=None, layer_past=None):
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if self.use_cache is True:
# Instead of present = (key, value), here we merge them into one tensor to be compatible with Attention operator.
present = MyGPT2Attention.concat_key_value(key, value)
else:
present = None
mask = MyGPT2Attention.process_mask(attention_mask, dtype=query.dtype) # mask processing is moved to here.
if self.debug:
attn_output, qk, norm_qk, softmax, attn_weights = self._attn(query, key, value, mask)
else:
attn_output = self._attn(query, key, value, mask)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
outputs = (attn_output, present)
if self.debug:
if "qk" in DEBUG_OUTPUTS:
outputs += (qk, )
if "norm_qk" in DEBUG_OUTPUTS:
outputs += (norm_qk, )
if "softmax" in DEBUG_OUTPUTS:
outputs += (softmax, )
if "attn_weights" in DEBUG_OUTPUTS:
outputs += (attn_weights, )
return outputs
def create_inputs(batch_size=1,
hidden_size=768,
num_attention_heads=12,
sequence_length=1,
past_sequence_length=5,
float16=False,
device=torch.device('cuda'),
padding_length=0):
float_type = torch.float16 if float16 else torch.float32
past_shape = [batch_size, num_attention_heads, past_sequence_length, int(hidden_size / num_attention_heads)]
past_key = torch.rand(past_shape, dtype=float_type, device=device)
past_value = torch.rand(past_shape, dtype=float_type, device=device)
layer_past = MyGPT2Attention.concat_key_value(past_key, past_value)
total_sequence_length = past_sequence_length + sequence_length
attention_mask = torch.ones([batch_size, total_sequence_length], dtype=torch.int32, device=device)
if padding_length > 0:
padding_position = total_sequence_length - padding_length
attention_mask[:, padding_position:] = 0
elif padding_length < 0: # mask a random position
for i in range(batch_size):
padding_position = random.randint(0, total_sequence_length - 1)
attention_mask[i, padding_position] = 0
input_hidden_states = torch.normal(mean=0.0, std=0.1,
size=(batch_size, sequence_length, hidden_size)).to(float_type).to(device)
return input_hidden_states, attention_mask, layer_past
def get_output_names(debug=False):
outputs = ["attn_output", "present"]
if debug:
outputs += DEBUG_OUTPUTS
return outputs
def export_onnx(model, onnx_model_path, float16, hidden_size, num_attention_heads, debug, device):
from pathlib import Path
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
input_hidden_states, attention_mask, layer_past = create_inputs(float16=float16,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
device=device)
with torch.no_grad():
outputs = model(input_hidden_states, attention_mask=attention_mask, layer_past=layer_past)
dynamic_axes = {
'input_hidden_states': {
0: 'batch_size',
1: 'seq_len'
},
"attn_output": {
0: 'batch_size',
1: 'seq_len'
},
"past": {
1: 'batch_size',
3: 'past_seq_len'
},
"present": {
1: 'batch_size',
3: 'total_seq_len'
},
"attention_mask": {
0: 'batch_size',
1: 'total_seq_len'
}
}
if debug:
debug_dynamic_axes = {
"qk": {
0: 'batch_size',
1: 'seq_len'
},
"norm_qk": {
0: 'batch_size',
1: 'seq_len'
},
"softmax": {
0: 'batch_size',
1: 'seq_len'
},
"attn_weights": {
0: 'batch_size',
1: 'seq_len'
}
}
for name in DEBUG_OUTPUTS:
dynamic_axes[name] = debug_dynamic_axes[name]
torch.onnx.export(model,
args=(input_hidden_states, {
'attention_mask': attention_mask,
'layer_past': layer_past
}),
f=onnx_model_path,
input_names=['input_hidden_states', 'attention_mask', 'past'],
output_names=get_output_names(debug),
dynamic_axes=dynamic_axes,
example_outputs=outputs,
opset_version=11,
do_constant_folding=True)
print("exported:", onnx_model_path)
def optimize_onnx(input_onnx_path, optimized_onnx_path, num_heads, debug):
from onnxruntime.transformers.onnx_model import OnnxModel
m = onnx.load(input_onnx_path)
onnx_model = OnnxModel(m)
nodes_to_remove = onnx_model.nodes()
output_names = ["attn_output", "present"] + DEBUG_OUTPUTS if debug else ["attn_output", "present"]
node_to_add = helper.make_node("Attention",
["input_hidden_states", "c_attn.weight", "c_attn.bias", "attention_mask", "past"],
output_names,
"gpt2_attention",
num_heads=num_heads,
unidirectional=1,
domain="com.microsoft")
onnx_model.remove_nodes(nodes_to_remove)
onnx_model.add_node(node_to_add)
onnx_model.prune_graph()
onnx_model.save_model_to_file(optimized_onnx_path)
def diff_outputs(torch_outputs, ort_outputs, index, relative=False):
""" Returns the maximum difference between PyTorch and OnnxRuntime outputs.
"""
expected_outputs = torch_outputs[index].cpu().numpy()
diff = numpy.abs(expected_outputs - ort_outputs[index])
if relative:
return numpy.amax(diff / (numpy.abs(expected_outputs) + 1e-6))
else:
return numpy.amax(diff)
def compare_outputs(torch_outputs, ort_outputs, rtol=1e-03, atol=1e-03, verbose=False):
""" Returns True if torch and ORT outputs are close for given thresholds, and False otherwise.
"""
is_all_close = True
max_abs_diff = []
for i in range(len(ort_outputs)):
is_close = numpy.allclose(ort_outputs[i], torch_outputs[i].cpu().numpy(), rtol=rtol, atol=atol)
if not is_close:
is_all_close = False
if verbose:
print(f'output {i} ({ort_outputs[i].name}) are close: {is_close}')
max_abs_diff.append(diff_outputs(torch_outputs, ort_outputs, i))
if (not is_all_close) or (verbose and max(max_abs_diff) > 0):
messages = ["max_abs_diff per output:"]
output_names = ["attn_output", "present"] + DEBUG_OUTPUTS
for i, diff in enumerate(max_abs_diff):
messages.append(f"{output_names[i]}={diff},")
print(" ".join(messages))
return is_all_close, max(max_abs_diff)
def create_ort_session(onnx_model_path, use_gpu=True):
from onnxruntime import SessionOptions, InferenceSession, GraphOptimizationLevel, __version__ as onnxruntime_version
sess_options = SessionOptions()
sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.intra_op_num_threads = 2
sess_options.log_severity_level = 2
execution_providers = ['CPUExecutionProvider'
] if not use_gpu else ['CUDAExecutionProvider', 'CPUExecutionProvider']
return InferenceSession(onnx_model_path, sess_options, providers=execution_providers)
def onnxruntime_inference(ort_session, input_hidden_states, attention_mask, past):
ort_inputs = {
'past': numpy.ascontiguousarray(past.cpu().numpy()),
'attention_mask': numpy.ascontiguousarray(attention_mask.cpu().numpy()),
'input_hidden_states': numpy.ascontiguousarray(input_hidden_states.cpu().numpy()),
}
ort_outputs = ort_session.run(None, ort_inputs)
return ort_outputs
def verify_attention(model,
onnx_model_path,
batch_size,
hidden_size,
num_attention_heads,
sequence_length,
past_sequence_length,
float16,
device,
padding_length,
optimized,
test_cases=100):
print(
f"optimized={optimized}, batch_size={batch_size}, hidden_size={hidden_size}, num_attention_heads={num_attention_heads}, sequence_length={sequence_length}, past_sequence_length={past_sequence_length}, float16={float16}, padding_length={padding_length}, device={device}"
)
passed_cases = 0
max_diffs = []
ort_session = create_ort_session(onnx_model_path, device.type == 'cuda')
for i in range(test_cases):
input_hidden_states, attention_mask, layer_past = create_inputs(batch_size, hidden_size, num_attention_heads,
sequence_length, past_sequence_length, float16,
device, padding_length)
with torch.no_grad():
torch_outputs = model(input_hidden_states, layer_past=layer_past, attention_mask=attention_mask)
ort_outputs = onnxruntime_inference(ort_session, input_hidden_states, attention_mask, layer_past)
tolerance = 1e-03 if float16 else 1e-05
is_all_close, max_diff = compare_outputs(torch_outputs, ort_outputs, rtol=tolerance, atol=tolerance)
max_diffs.append(max_diff)
if is_all_close:
passed_cases += 1
max_diff = max(max_diffs)
diff_count = len([i for i in max_diffs if i > 0])
success_flag = "[FAILED]" if passed_cases < test_cases else "[OK]"
print(f"{success_flag} Passed_cases={passed_cases}/{test_cases}; Max_diff={max_diff}; Diff_count={diff_count}")
return test_cases - passed_cases
def run(batch_size, float16, optimized, hidden_size, num_attention_heads, device, test_cases):
test_name = f"batch_size={batch_size}, float16={float16}, optimized={optimized}, hidden_size={hidden_size}, num_attention_heads={num_attention_heads}"
print(f"\nTesting ONNX parity: {test_name}")
debug = (not optimized
) # or DEBUG_OUTPUTS==["softmax"] when you add an extra output for softmax result in Attention operator
model = MyGPT2Attention(hidden_size=hidden_size, num_attention_heads=num_attention_heads, debug=debug)
model.eval()
model.to(device)
if float16:
model.half()
# Do not re-use onnx file from previous test since weights of model are random.
onnx_model_path = './temp/gpt_attention_{}.onnx'.format("fp16" if float16 else "fp32")
export_onnx(model, onnx_model_path, float16, hidden_size, num_attention_heads, debug, device)
if optimized:
optimized_onnx_path = './temp/gpt_attention_opt_{}.onnx'.format("fp16" if float16 else "fp32")
optimize_onnx(onnx_model_path, optimized_onnx_path, num_attention_heads, debug)
onnx_path = optimized_onnx_path
else:
onnx_path = onnx_model_path
# Test Case: No past state
sequence_length = 2
past_sequence_length = 0
padding_length = 0
num_failure = 0
num_failure += verify_attention(model,
onnx_path,
batch_size,
hidden_size,
num_attention_heads,
sequence_length,
past_sequence_length,
float16,
device,
padding_length,
optimized,
test_cases)
# Test Case: with past state and padding last 2 words
sequence_length = 3
past_sequence_length = 5
padding_length = 2
num_failure += verify_attention(model,
onnx_path,
batch_size,
hidden_size,
num_attention_heads,
sequence_length,
past_sequence_length,
float16,
device,
padding_length,
optimized,
test_cases)
# Test Case: random mask one word
sequence_length = 1
past_sequence_length = 128
padding_length = -1
num_failure += verify_attention(model,
onnx_path,
batch_size,
hidden_size,
num_attention_heads,
sequence_length,
past_sequence_length,
float16,
device,
padding_length,
optimized,
test_cases)
# clean up onnx file
os.remove(onnx_model_path)
if optimized:
os.remove(onnx_path)
return num_failure, test_name
class TestGptAttentionHuggingfaceParity(unittest.TestCase):
def setUp(self):
self.optimized = True # Change it to False if you want to test parity of non optimized ONNX
self.test_cases = 10 # Number of test cases per test run
def run_test(self, batch_size, float16, optimized, hidden_size, num_attention_heads, device):
if float16 and device.type=='cpu': # CPU does not support FP16
return
num_failure, test_name = run(batch_size, float16, optimized, hidden_size, num_attention_heads, device, self.test_cases)
self.assertTrue(num_failure == 0, test_name)
def run_small(self, optimized, device):
for batch_size in [64]:
self.run_test(batch_size, float16=False, optimized=optimized, hidden_size=768, num_attention_heads=12, device=device)
self.run_test(batch_size, float16=True, optimized=optimized, hidden_size=768, num_attention_heads=12, device=device)
def run_large(self, optimized, device):
for batch_size in [2]:
self.run_test(batch_size, float16=False, optimized=optimized, hidden_size=4096, num_attention_heads=32, device=device)
self.run_test(batch_size, float16=True, optimized=optimized, hidden_size=4096, num_attention_heads=32, device=device)
def test_cpu(self):
cpu = torch.device('cpu')
self.run_small(self.optimized, cpu)
def test_cuda(self):
if not torch.cuda.is_available():
import pytest
pytest.skip('test requires GPU and torch+cuda')
else:
gpu = torch.device('cuda')
self.run_small(self.optimized, gpu)
@pytest.mark.slow
def test_large_cuda(self):
if not torch.cuda.is_available():
import pytest
pytest.skip('test requires GPU and torch+cuda')
else:
gpu = torch.device('cuda')
self.run_large(self.optimized, gpu)
if __name__ == '__main__':
unittest.main()