mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-30 23:18:20 +00:00
Fix attention parity for GPT-2 (#8549)
* Use persistent softmax to parity with huggingface * fix undirectional mask logic * add test
This commit is contained in:
parent
816ad86d14
commit
330b8e74bd
11 changed files with 680 additions and 74 deletions
|
|
@ -112,11 +112,12 @@ class AttentionCPUBase : public AttentionBase {
|
|||
|
||||
{
|
||||
if (mask_data != nullptr) {
|
||||
PrepareMask(mask_index, mask_index_dims, mask_data, is_unidirectional_, batch_size, sequence_length, past_sequence_length);
|
||||
} else { // no any mask
|
||||
memset(attention_probs, 0, static_cast<size_t>(batch_size) * num_heads_ * sequence_length * all_sequence_length * sizeof(T));
|
||||
// Convert attention mask data from int to float (0 to -10000.0f). The mask_data shape is BxSxS*.
|
||||
PrepareMask(mask_index, mask_index_dims, mask_data, batch_size, sequence_length, past_sequence_length);
|
||||
}
|
||||
|
||||
memset(attention_probs, 0, static_cast<size_t>(batch_size) * num_heads_ * sequence_length * all_sequence_length * sizeof(T));
|
||||
|
||||
const int loop_len = batch_size * num_heads_;
|
||||
const float alpha = 1.0f / sqrt(static_cast<float>(head_size));
|
||||
|
||||
|
|
@ -127,32 +128,45 @@ class AttentionCPUBase : public AttentionBase {
|
|||
for (std::ptrdiff_t i = begin; i != end; ++i) {
|
||||
const std::ptrdiff_t batch_index = i / num_heads_;
|
||||
|
||||
// broadcast mask data: (Bx)SxS* -> (BxNx)SxS*
|
||||
if (mask_data != nullptr) {
|
||||
const T* broadcast_data_src = reinterpret_cast<T*>(mask_data) + batch_index * sequence_length * all_sequence_length;
|
||||
T* broadcast_data_dest = reinterpret_cast<T*>(attention_probs) + sequence_length * all_sequence_length * i;
|
||||
memcpy(broadcast_data_dest, broadcast_data_src, sequence_length * all_sequence_length * sizeof(T));
|
||||
}
|
||||
|
||||
const T* k = K + input_chunk_length * i;
|
||||
if (nullptr != present) {
|
||||
// concatenate past_K and K : (BxNx)S'xH, (BxNx)SxH -> (BxNx)S*xH
|
||||
// Concatenate past_K and K : (BxNx)S'xH, (BxNx)SxH -> (BxNx)S*xH
|
||||
k = ConcatStateChunk(past, k, present, past_chunk_length, present_chunk_length, i);
|
||||
}
|
||||
|
||||
// gemm
|
||||
int offset = sequence_length * all_sequence_length * static_cast<int>(i);
|
||||
T* output = reinterpret_cast<T*>(attention_probs) + offset;
|
||||
|
||||
// Compute Q*K'
|
||||
// original transposed each iteration
|
||||
// A: Q (B x N x) S x H (B x N x) S x H S x H
|
||||
// B: K' (B x N x) S* x H (B x N x) H x S* H x S*
|
||||
// C: attention_probs (B x N x) S x S* (B x N x) S x S* S x S*
|
||||
math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, all_sequence_length, head_size, alpha,
|
||||
Q + input_chunk_length * i, k, 1.0,
|
||||
reinterpret_cast<T*>(attention_probs) + sequence_length * all_sequence_length * i, nullptr);
|
||||
output, nullptr);
|
||||
|
||||
// Apply unidirectional mask and set future words to -10000.0f.
|
||||
if (is_unidirectional_) {
|
||||
for (int s = 0; s < sequence_length - 1; s++) {
|
||||
for (int t = past_sequence_length + s + 1; t < all_sequence_length; t++) {
|
||||
output[s * all_sequence_length + t] = static_cast<T>(-10000.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply attention mask
|
||||
if (mask_data != nullptr) {
|
||||
const T* attention_mask = reinterpret_cast<T*>(mask_data) + batch_index * sequence_length * all_sequence_length;
|
||||
for (int j = 0; j < sequence_length * all_sequence_length ; j++) {
|
||||
output[j] += attention_mask[j];
|
||||
}
|
||||
}
|
||||
|
||||
if (extra_add_qk_data != nullptr) {
|
||||
int extra_add_qk_offset = static_cast<int>(i) * sequence_length * all_sequence_length;
|
||||
const T* extra = extra_add_qk_data + offset;
|
||||
for (int j = 0; j < sequence_length * all_sequence_length ; j++) {
|
||||
attention_probs[extra_add_qk_offset+j] += extra_add_qk_data[extra_add_qk_offset + j];
|
||||
output[j] += extra[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -64,7 +64,6 @@ template <typename T>
|
|||
void PrepareMask(const int32_t* mask_index,
|
||||
const std::vector<int64_t>* mask_index_dims,
|
||||
T* mask_data,
|
||||
bool is_unidirectional,
|
||||
int batch_size,
|
||||
int sequence_length,
|
||||
int past_sequence_length) {
|
||||
|
|
@ -84,18 +83,6 @@ void PrepareMask(const int32_t* mask_index,
|
|||
for (int i = 0; i < batch_size * sequence_length * all_sequence_length; i++) {
|
||||
p_mask[i] = (mask_index[i] > 0) ? static_cast<T>(0.0f) : static_cast<T>(-10000.0f);
|
||||
}
|
||||
|
||||
if (is_unidirectional) {
|
||||
for (int b_i = 0; b_i < batch_size; b_i++) {
|
||||
for (int s_i = 0; s_i < sequence_length - 1; s_i++) {
|
||||
for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) {
|
||||
p_mask[s_i * all_sequence_length + m_i] += static_cast<T>(-10000.0f);
|
||||
}
|
||||
}
|
||||
p_mask += sequence_length * all_sequence_length;
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -135,15 +122,6 @@ void PrepareMask(const int32_t* mask_index,
|
|||
memcpy(p_mask + s_i * all_sequence_length, p_mask, all_sequence_length * sizeof(T));
|
||||
}
|
||||
|
||||
// Apply unidirectional mask.
|
||||
if (is_unidirectional) {
|
||||
for (int s_i = 0; s_i < sequence_length - 1; s_i++) {
|
||||
for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) {
|
||||
p_mask[s_i * all_sequence_length + m_i] += static_cast<T>(-10000.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p_mask += sequence_length * all_sequence_length;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
#include "attention_impl.h"
|
||||
#include "attention_softmax.h"
|
||||
#include "transformer_common.h"
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
using namespace cub;
|
||||
|
|
@ -65,7 +66,7 @@ bool QkvToContext(
|
|||
const int batch_size, const int sequence_length, const int num_heads, const int head_size, const size_t element_size,
|
||||
const T* input, T* output, T* workspace,
|
||||
const int* mask_index, const std::vector<int64_t>* mask_index_dims,
|
||||
bool is_unidirectional, int past_sequence_length, const T* past, T* present) {
|
||||
bool is_unidirectional, int past_sequence_length, const T* past, T* present, bool use_persistent_softmax) {
|
||||
const int all_sequence_length = past_sequence_length + sequence_length;
|
||||
const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length);
|
||||
T* scratch1 = workspace;
|
||||
|
|
@ -127,8 +128,11 @@ bool QkvToContext(
|
|||
if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask
|
||||
const int mask_dimension = static_cast<int>(mask_index_dims->size());
|
||||
const int64_t max_sequence_length = mask_dimension == 4 ? mask_index_dims->at(3) : 0;
|
||||
|
||||
T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score if persistent softmax is selected.
|
||||
if (!ComputeSoftmaxWithRawMask<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2, is_unidirectional,
|
||||
rsqrt_head_size, mask_dimension, static_cast<int>(max_sequence_length))) {
|
||||
rsqrt_head_size, mask_dimension, static_cast<int>(max_sequence_length),
|
||||
use_persistent_softmax, persistent_softmax_workspace)) {
|
||||
return false;
|
||||
}
|
||||
} else if (nullptr != mask_index) { // 1d mask index
|
||||
|
|
@ -173,18 +177,27 @@ bool LaunchAttentionKernel(
|
|||
int past_sequence_length,
|
||||
const void* past,
|
||||
void* present) {
|
||||
|
||||
// GPT-2 model is more sensitive on parity since error will accumulate in text generation.
|
||||
// So use persistent softmax for GPT-2 model by default.
|
||||
// For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 or 2 could enable or disable it explicitly.
|
||||
const TransformerOptions* options = TransformerOptions::GetInstance();
|
||||
bool use_persistent_softmax = (is_unidirectional || options->IsPrecisionMode()) && !options->DisablePersistentSoftmax();
|
||||
|
||||
if (element_size == 2) {
|
||||
return QkvToContext(prop, cublas, stream,
|
||||
batch_size, sequence_length, num_heads, head_size, element_size,
|
||||
reinterpret_cast<const half*>(input), reinterpret_cast<half*>(output), reinterpret_cast<half*>(workspace),
|
||||
mask_index, mask_index_dims, is_unidirectional,
|
||||
past_sequence_length, reinterpret_cast<const half*>(past), reinterpret_cast<half*>(present));
|
||||
past_sequence_length, reinterpret_cast<const half*>(past), reinterpret_cast<half*>(present),
|
||||
use_persistent_softmax);
|
||||
} else {
|
||||
return QkvToContext(prop, cublas, stream,
|
||||
batch_size, sequence_length, num_heads, head_size, element_size,
|
||||
reinterpret_cast<const float*>(input), reinterpret_cast<float*>(output), reinterpret_cast<float*>(workspace),
|
||||
mask_index, mask_index_dims, is_unidirectional,
|
||||
past_sequence_length, reinterpret_cast<const float*>(past), reinterpret_cast<float*>(present));
|
||||
past_sequence_length, reinterpret_cast<const float*>(past), reinterpret_cast<float*>(present),
|
||||
use_persistent_softmax);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -19,11 +19,13 @@ limitations under the License.
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include <cub/cub.cuh>
|
||||
#include <cuda_fp16.h>
|
||||
#include <math_constants.h>
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/math/softmax.h"
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
using namespace cub;
|
||||
|
|
@ -165,9 +167,10 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
|
|||
const T* input,
|
||||
T* output,
|
||||
const bool is_unidirectional,
|
||||
const float scalar,
|
||||
const float rsqrt_head_size,
|
||||
const int mask_dimension,
|
||||
const int max_sequence_length) {
|
||||
const int max_sequence_length,
|
||||
const bool skip_softmax) {
|
||||
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp_storage;
|
||||
|
||||
|
|
@ -179,30 +182,36 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
|
|||
|
||||
float thread_data = -CUDART_INF_F;
|
||||
if (threadIdx.x < all_sequence_length) {
|
||||
const int batch_index = blockIdx.y;
|
||||
thread_data = float(input[index]) * rsqrt_head_size;
|
||||
|
||||
const int sequence_index = blockIdx.x % sequence_length;
|
||||
if (is_unidirectional) {
|
||||
int from_index = all_sequence_length - sequence_length + sequence_index; // offset of from token in all sequence length.
|
||||
if (threadIdx.x > from_index) {
|
||||
thread_data = -10000.0f;
|
||||
}
|
||||
}
|
||||
|
||||
int mask_offset = 0;
|
||||
const int batch_index = blockIdx.y;
|
||||
if (mask_dimension == 2) {
|
||||
mask_offset = batch_index * all_sequence_length + threadIdx.x;
|
||||
} else if (mask_dimension == 3) {
|
||||
mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + threadIdx.x;
|
||||
} else if (mask_dimension == 4){
|
||||
// Megatron code:
|
||||
// ltor_mask = ltor_mask[..., (attention_scores.size(3)-hidden_states.size(1)):attention_scores.size(3), :attention_scores.size(3)]
|
||||
} else if (mask_dimension == 4) {
|
||||
mask_offset = (batch_index * max_sequence_length + all_sequence_length - sequence_length + sequence_index) * max_sequence_length + threadIdx.x;
|
||||
}
|
||||
|
||||
const int& mask = attention_mask[mask_offset];
|
||||
float mask_value = mask > 0 ? 0.0f : -10000.0f;
|
||||
if (mask == 0)
|
||||
thread_data += -10000.0f;
|
||||
}
|
||||
|
||||
if (is_unidirectional) {
|
||||
int from_index = all_sequence_length - sequence_length + sequence_index; // offset of from token in all sequence length.
|
||||
if (threadIdx.x > from_index) {
|
||||
mask_value += -10000.0f;
|
||||
}
|
||||
if (skip_softmax) {
|
||||
if (threadIdx.x < all_sequence_length) {
|
||||
output[index] = T(thread_data);
|
||||
}
|
||||
|
||||
thread_data = float(input[index]) * scalar + mask_value;
|
||||
return;
|
||||
}
|
||||
|
||||
const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), all_sequence_length);
|
||||
|
|
@ -264,7 +273,7 @@ bool ComputeSoftmax(
|
|||
const int blockSize = 1024;
|
||||
SoftmaxKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output);
|
||||
} else {
|
||||
ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024.");
|
||||
ORT_THROW("Attention CUDA operator does not support total sequence length > 1024.");
|
||||
}
|
||||
|
||||
return CUDA_CALL(cudaPeekAtLastError());
|
||||
|
|
@ -314,8 +323,9 @@ __global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int seq
|
|||
|
||||
template <typename T, unsigned TPB>
|
||||
__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, const int sequence_length, const int* attention_mask, const T* input, T* output,
|
||||
const bool is_unidirectional, const float scalar, const int mask_dimension, const int max_sequence_length) {
|
||||
SoftmaxWithRawMaskSmall<T, TPB>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
|
||||
const bool is_unidirectional, const float rsqrt_head_size, const int mask_dimension, const int max_sequence_length,
|
||||
const bool skip_softmax) {
|
||||
SoftmaxWithRawMaskSmall<T, TPB>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, skip_softmax);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -352,7 +362,7 @@ bool ComputeSoftmaxWithMask1D(cudaStream_t stream, const int all_sequence_length
|
|||
MaskedSoftmaxKernel<T, blockSize>
|
||||
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output);
|
||||
} else {
|
||||
ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024.");
|
||||
ORT_THROW("Attention CUDA operator does not support total sequence length > 1024.");
|
||||
}
|
||||
|
||||
return CUDA_CALL(cudaPeekAtLastError());
|
||||
|
|
@ -360,36 +370,42 @@ bool ComputeSoftmaxWithMask1D(cudaStream_t stream, const int all_sequence_length
|
|||
|
||||
template <typename T>
|
||||
bool ComputeSoftmaxWithRawMask(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
|
||||
const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar,
|
||||
const int mask_dimension, const int max_sequence_length) {
|
||||
const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float rsqrt_head_size,
|
||||
const int mask_dimension, const int max_sequence_length,
|
||||
const bool use_persistent_softmax, T* persistent_softmax_workspace) {
|
||||
const dim3 grid(sequence_length * num_heads, batch_size, 1);
|
||||
|
||||
T* out = use_persistent_softmax ? persistent_softmax_workspace : output;
|
||||
if (all_sequence_length <= 32) {
|
||||
const int blockSize = 32;
|
||||
SoftmaxWithRawMaskSmallKernel<T, blockSize>
|
||||
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
|
||||
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
|
||||
} else if (all_sequence_length <= 64) {
|
||||
const int blockSize = 64;
|
||||
SoftmaxWithRawMaskSmallKernel<T, blockSize>
|
||||
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
|
||||
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
|
||||
} else if (all_sequence_length <= 128) {
|
||||
const int blockSize = 128;
|
||||
SoftmaxWithRawMaskSmallKernel<T, blockSize>
|
||||
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
|
||||
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
|
||||
} else if (all_sequence_length <= 256) {
|
||||
const int blockSize = 256;
|
||||
SoftmaxWithRawMaskSmallKernel<T, blockSize>
|
||||
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
|
||||
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
|
||||
} else if (all_sequence_length <= 512) {
|
||||
const int blockSize = 512;
|
||||
SoftmaxWithRawMaskSmallKernel<T, blockSize>
|
||||
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
|
||||
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
|
||||
} else if (all_sequence_length <= 1024) {
|
||||
const int blockSize = 1024;
|
||||
SoftmaxWithRawMaskSmallKernel<T, blockSize>
|
||||
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length);
|
||||
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
|
||||
} else {
|
||||
ORT_THROW("Attention CUDA operator does not supported 2D attention mask with total sequence length > 1024.");
|
||||
ORT_THROW("Attention CUDA operator does not support total sequence length > 1024.");
|
||||
}
|
||||
|
||||
if (use_persistent_softmax) {
|
||||
dispatch_softmax_forward<T, T, float, false>(stream, output, persistent_softmax_workspace, all_sequence_length, all_sequence_length, batch_size * num_heads * sequence_length);
|
||||
}
|
||||
|
||||
return CUDA_CALL(cudaPeekAtLastError());
|
||||
|
|
|
|||
35
onnxruntime/contrib_ops/cuda/bert/transformer_common.cc
Normal file
35
onnxruntime/contrib_ops/cuda/bert/transformer_common.cc
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#include <iostream>
|
||||
#include "core/providers/shared_library/provider_api.h" // Include this otherwise Windows build complains Env::Default() missing
|
||||
#include "core/platform/env_var_utils.h"
|
||||
#include "transformer_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
// The environment variable is for testing purpose only, and it might be removed in the future.
|
||||
// If you need some option in production, please file a feature request.
|
||||
constexpr const char* kTransformerOptions = "ORT_TRANSFORMER_OPTIONS";
|
||||
|
||||
// Initialize the singleton instance
|
||||
TransformerOptions TransformerOptions::instance;
|
||||
|
||||
const TransformerOptions* TransformerOptions::GetInstance() {
|
||||
if (!instance.initialized_) {
|
||||
// We do not use critical section here since it is fine to initialize multiple times by different threads.
|
||||
int value = ParseEnvironmentVariableWithDefault<int>(kTransformerOptions, 0);
|
||||
instance.Initialize(value);
|
||||
|
||||
if (value > 0)
|
||||
std::cout << "ORT_TRANSFORMER_OPTIONS: IsPrecisionMode=" << instance.IsPrecisionMode()
|
||||
<< ",DisablePersistentSoftmax=" << instance.DisablePersistentSoftmax() << std::endl;
|
||||
}
|
||||
|
||||
return &instance;
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
38
onnxruntime/contrib_ops/cuda/bert/transformer_common.h
Normal file
38
onnxruntime/contrib_ops/cuda/bert/transformer_common.h
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
class TransformerOptions {
|
||||
public:
|
||||
static const TransformerOptions* GetInstance();
|
||||
|
||||
bool IsPrecisionMode() const { return is_precision_mode_; }
|
||||
|
||||
bool DisablePersistentSoftmax() const { return disable_persistent_softmax_; }
|
||||
|
||||
void Initialize(int value) {
|
||||
is_precision_mode_ = (value & 0x01) > 0;
|
||||
disable_persistent_softmax_ = (value & 0x02) > 0;
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
private:
|
||||
// Default is false. If the mode is on, prefer precision than speed.
|
||||
bool is_precision_mode_{false};
|
||||
|
||||
// Disable persistent softmax.
|
||||
bool disable_persistent_softmax_{false};
|
||||
|
||||
bool initialized_{false};
|
||||
|
||||
static TransformerOptions instance;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -39,6 +39,7 @@ class FusionSkipLayerNormalization(Fusion):
|
|||
|
||||
if self.shape_infer_helper is not None:
|
||||
if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]):
|
||||
logger.debug(f"skip skiplayernorm fusion since shape of inputs ({add.input[0]}, {add.input[1]}) are not same")
|
||||
return
|
||||
else:
|
||||
# shape_infer_helper can not handle subgraphs. Current work around is to disable skiplayernorm fusion
|
||||
|
|
|
|||
|
|
@ -799,6 +799,7 @@ class OnnxModel:
|
|||
else:
|
||||
deps_to_nodes[input_name].append(node_idx)
|
||||
|
||||
# Note: this logic only applies to top level graph since a sub graph could use intializer from parent graph
|
||||
initializer_names = [init.name for init in graph.initializer]
|
||||
graph_input_names = [input.name for input in graph.input]
|
||||
input_names = initializer_names + graph_input_names
|
||||
|
|
|
|||
|
|
@ -185,9 +185,6 @@ class BertOnnxModel(OnnxModel):
|
|||
return
|
||||
|
||||
def adjust_reshape_and_expand(self):
|
||||
# Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
|
||||
FusionUtils.remove_useless_reshape_nodes(self)
|
||||
|
||||
nodes_to_remove = []
|
||||
for node in self.nodes():
|
||||
if node.op_type == 'Reshape':
|
||||
|
|
@ -289,7 +286,9 @@ class BertOnnxModel(OnnxModel):
|
|||
if (options is None) or options.enable_embed_layer_norm:
|
||||
self.fuse_embed_layer()
|
||||
|
||||
# Post-processing like removing extra reshape nodes.
|
||||
# Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
|
||||
FusionUtils.remove_useless_reshape_nodes(self)
|
||||
|
||||
self.postprocess()
|
||||
|
||||
# Bias fusion is done after postprocess to avoid extra Reshape between bias and Gelu/FastGelu/SkipLayerNormalization
|
||||
|
|
|
|||
|
|
@ -1291,7 +1291,7 @@ TEST(AttentionTest, AttentionUnidirectional3DMask) {
|
|||
1, 1};
|
||||
|
||||
std::vector<float> output_data = {
|
||||
3.967245340f, 0.07324841f, 4.25f, 5.65f,
|
||||
3.0146f, 0.1142f, 3.9834f, 5.3394f,
|
||||
8.69f, -0.13f, 4.25f, 5.65f,
|
||||
8.69f, -0.13f, 4.25f, 5.65f,
|
||||
3.96967912f, 0.07314367f, 4.25f, 5.65f};
|
||||
|
|
@ -1332,7 +1332,7 @@ TEST(AttentionTest, AttentionUnidirectionalAttentionMask) {
|
|||
std::vector<int32_t> mask_index_data = {0, 1, 1, 1};
|
||||
|
||||
std::vector<float> output_data = {
|
||||
3.967245340f, 0.07324841f, 4.25f, 5.65f,
|
||||
3.0146f, 0.1142f, 3.9834f, 5.3394f,
|
||||
8.69f, -0.13f, 4.25f, 5.65f,
|
||||
8.69f, -0.13f, 4.25f, 5.65f,
|
||||
3.96967912f, 0.07314367f, 4.25f, 5.65f};
|
||||
|
|
|
|||
|
|
@ -0,0 +1,511 @@
|
|||
# --------------------------------------------------------------------------
|
||||
# Copyright 2020 The HuggingFace Inc. team
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
# --------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
import random
|
||||
from onnx import helper
|
||||
import onnx
|
||||
import numpy
|
||||
import os
|
||||
from transformers.modeling_utils import Conv1D
|
||||
|
||||
DEBUG_OUTPUTS = ["qk", "norm_qk", "softmax", "attn_weights"]
|
||||
|
||||
|
||||
class MyGPT2Attention(nn.Module):
|
||||
"""
|
||||
This module is modifed from Gpt2Attention of huggingface transformers v4.9.1.
|
||||
Code related to crosss attention, c_proj, attn_dropout and head_mask etc are removed.
|
||||
"""
|
||||
def __init__(self,
|
||||
max_position_embeddings=1024,
|
||||
hidden_size=768,
|
||||
num_attention_heads=12,
|
||||
use_cache=True,
|
||||
debug=False,
|
||||
fix_onnx_export=True):
|
||||
super().__init__()
|
||||
max_positions = max_position_embeddings
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.tril(torch.ones((max_positions, max_positions),
|
||||
dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
|
||||
)
|
||||
self.register_buffer("masked_bias", torch.tensor(-1e4))
|
||||
self.embed_dim = hidden_size
|
||||
self.num_heads = num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
self.split_size = self.embed_dim
|
||||
assert self.head_dim * self.num_heads == self.embed_dim
|
||||
|
||||
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
|
||||
# Use random bias instead of zeros for parity test.
|
||||
self.c_attn.bias = nn.Parameter(torch.normal(0.0, 0.1, (3 * self.embed_dim, )))
|
||||
|
||||
self.use_cache = use_cache
|
||||
self.debug = debug
|
||||
self.fix_onnx_export = fix_onnx_export
|
||||
|
||||
def _attn(self, query, key, value, attention_mask=None):
|
||||
qk = torch.matmul(query, key.transpose(-1, -2))
|
||||
|
||||
# Torch has special handling for Div and Mul by a scalar:
|
||||
# https://github.com/pytorch/pytorch/blob/5536cda19a5def9e0553b318f04d297d602ac956/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu#L52-L60
|
||||
# https://github.com/pytorch/pytorch/blob/5536cda19a5def9e0553b318f04d297d602ac956/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu#L185-L194
|
||||
# Modify the code to use same processing in onnx export so as to get parity result without attention fusion.
|
||||
# This walkaround is not needed when attention fusion will be applied later since the subgraph will be replaced by an Attention node.
|
||||
if self.fix_onnx_export and torch.onnx.is_in_onnx_export():
|
||||
if qk.dtype == torch.float16:
|
||||
norm_qk = qk.to(torch.float32) * (1.0 / (float(value.size(-1))**0.5))
|
||||
norm_qk = norm_qk.to(torch.float16)
|
||||
else:
|
||||
norm_qk = qk * (1.0 / (float(value.size(-1))**0.5))
|
||||
else:
|
||||
norm_qk = qk / (float(value.size(-1))**0.5)
|
||||
|
||||
query_length, key_length = query.size(-2), key.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool()
|
||||
attn_weights = torch.where(causal_mask, norm_qk, self.masked_bias.to(norm_qk.dtype))
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
softmax = nn.Softmax(dim=-1)(attn_weights)
|
||||
|
||||
attn_output = torch.matmul(softmax, value)
|
||||
|
||||
if self.debug:
|
||||
return attn_output, qk, norm_qk, softmax, attn_weights
|
||||
else:
|
||||
return attn_output
|
||||
|
||||
def _split_heads(self, tensor, num_heads, attn_head_size):
|
||||
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
|
||||
tensor = tensor.view(*new_shape)
|
||||
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
||||
|
||||
def _merge_heads(self, tensor, num_heads, attn_head_size):
|
||||
tensor = tensor.permute(0, 2, 1, 3).contiguous()
|
||||
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size, )
|
||||
return tensor.view(new_shape)
|
||||
|
||||
@staticmethod
|
||||
def concat_key_value(key, value):
|
||||
return torch.cat((key.unsqueeze(0), value.unsqueeze(0)), dim=0)
|
||||
|
||||
@staticmethod
|
||||
def process_mask(attention_mask, dtype):
|
||||
# Create a 4D attention mask with shape [batch_size, 1, 1, to_seq_length] from a 2D tensor mask.
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
attention_mask = attention_mask.to(dtype=dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||
return attention_mask
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, layer_past=None):
|
||||
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
||||
|
||||
query = self._split_heads(query, self.num_heads, self.head_dim)
|
||||
key = self._split_heads(key, self.num_heads, self.head_dim)
|
||||
value = self._split_heads(value, self.num_heads, self.head_dim)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
key = torch.cat((past_key, key), dim=-2)
|
||||
value = torch.cat((past_value, value), dim=-2)
|
||||
|
||||
if self.use_cache is True:
|
||||
# Instead of present = (key, value), here we merge them into one tensor to be compatible with Attention operator.
|
||||
present = MyGPT2Attention.concat_key_value(key, value)
|
||||
else:
|
||||
present = None
|
||||
|
||||
mask = MyGPT2Attention.process_mask(attention_mask, dtype=query.dtype) # mask processing is moved to here.
|
||||
|
||||
if self.debug:
|
||||
attn_output, qk, norm_qk, softmax, attn_weights = self._attn(query, key, value, mask)
|
||||
else:
|
||||
attn_output = self._attn(query, key, value, mask)
|
||||
|
||||
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
||||
|
||||
outputs = (attn_output, present)
|
||||
if self.debug:
|
||||
if "qk" in DEBUG_OUTPUTS:
|
||||
outputs += (qk, )
|
||||
if "norm_qk" in DEBUG_OUTPUTS:
|
||||
outputs += (norm_qk, )
|
||||
if "softmax" in DEBUG_OUTPUTS:
|
||||
outputs += (softmax, )
|
||||
if "attn_weights" in DEBUG_OUTPUTS:
|
||||
outputs += (attn_weights, )
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def create_inputs(batch_size=1,
|
||||
hidden_size=768,
|
||||
num_attention_heads=12,
|
||||
sequence_length=1,
|
||||
past_sequence_length=5,
|
||||
float16=False,
|
||||
device=torch.device('cuda'),
|
||||
padding_length=0):
|
||||
float_type = torch.float16 if float16 else torch.float32
|
||||
|
||||
past_shape = [batch_size, num_attention_heads, past_sequence_length, int(hidden_size / num_attention_heads)]
|
||||
past_key = torch.rand(past_shape, dtype=float_type, device=device)
|
||||
past_value = torch.rand(past_shape, dtype=float_type, device=device)
|
||||
layer_past = MyGPT2Attention.concat_key_value(past_key, past_value)
|
||||
|
||||
total_sequence_length = past_sequence_length + sequence_length
|
||||
|
||||
attention_mask = torch.ones([batch_size, total_sequence_length], dtype=torch.int32, device=device)
|
||||
if padding_length > 0:
|
||||
padding_position = total_sequence_length - padding_length
|
||||
attention_mask[:, padding_position:] = 0
|
||||
elif padding_length < 0: # mask a random position
|
||||
for i in range(batch_size):
|
||||
padding_position = random.randint(0, total_sequence_length - 1)
|
||||
attention_mask[i, padding_position] = 0
|
||||
|
||||
input_hidden_states = torch.normal(mean=0.0, std=0.1,
|
||||
size=(batch_size, sequence_length, hidden_size)).to(float_type).to(device)
|
||||
return input_hidden_states, attention_mask, layer_past
|
||||
|
||||
|
||||
def get_output_names(debug=False):
|
||||
outputs = ["attn_output", "present"]
|
||||
if debug:
|
||||
outputs += DEBUG_OUTPUTS
|
||||
return outputs
|
||||
|
||||
|
||||
def export_onnx(model, onnx_model_path, float16, hidden_size, num_attention_heads, debug, device):
|
||||
from pathlib import Path
|
||||
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
input_hidden_states, attention_mask, layer_past = create_inputs(float16=float16,
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(input_hidden_states, attention_mask=attention_mask, layer_past=layer_past)
|
||||
|
||||
dynamic_axes = {
|
||||
'input_hidden_states': {
|
||||
0: 'batch_size',
|
||||
1: 'seq_len'
|
||||
},
|
||||
"attn_output": {
|
||||
0: 'batch_size',
|
||||
1: 'seq_len'
|
||||
},
|
||||
"past": {
|
||||
1: 'batch_size',
|
||||
3: 'past_seq_len'
|
||||
},
|
||||
"present": {
|
||||
1: 'batch_size',
|
||||
3: 'total_seq_len'
|
||||
},
|
||||
"attention_mask": {
|
||||
0: 'batch_size',
|
||||
1: 'total_seq_len'
|
||||
}
|
||||
}
|
||||
if debug:
|
||||
debug_dynamic_axes = {
|
||||
"qk": {
|
||||
0: 'batch_size',
|
||||
1: 'seq_len'
|
||||
},
|
||||
"norm_qk": {
|
||||
0: 'batch_size',
|
||||
1: 'seq_len'
|
||||
},
|
||||
"softmax": {
|
||||
0: 'batch_size',
|
||||
1: 'seq_len'
|
||||
},
|
||||
"attn_weights": {
|
||||
0: 'batch_size',
|
||||
1: 'seq_len'
|
||||
}
|
||||
}
|
||||
for name in DEBUG_OUTPUTS:
|
||||
dynamic_axes[name] = debug_dynamic_axes[name]
|
||||
|
||||
torch.onnx.export(model,
|
||||
args=(input_hidden_states, {
|
||||
'attention_mask': attention_mask,
|
||||
'layer_past': layer_past
|
||||
}),
|
||||
f=onnx_model_path,
|
||||
input_names=['input_hidden_states', 'attention_mask', 'past'],
|
||||
output_names=get_output_names(debug),
|
||||
dynamic_axes=dynamic_axes,
|
||||
example_outputs=outputs,
|
||||
opset_version=11,
|
||||
do_constant_folding=True)
|
||||
print("exported:", onnx_model_path)
|
||||
|
||||
|
||||
def optimize_onnx(input_onnx_path, optimized_onnx_path, num_heads, debug):
|
||||
from onnxruntime.transformers.onnx_model import OnnxModel
|
||||
m = onnx.load(input_onnx_path)
|
||||
onnx_model = OnnxModel(m)
|
||||
|
||||
nodes_to_remove = onnx_model.nodes()
|
||||
output_names = ["attn_output", "present"] + DEBUG_OUTPUTS if debug else ["attn_output", "present"]
|
||||
node_to_add = helper.make_node("Attention",
|
||||
["input_hidden_states", "c_attn.weight", "c_attn.bias", "attention_mask", "past"],
|
||||
output_names,
|
||||
"gpt2_attention",
|
||||
num_heads=num_heads,
|
||||
unidirectional=1,
|
||||
domain="com.microsoft")
|
||||
|
||||
onnx_model.remove_nodes(nodes_to_remove)
|
||||
onnx_model.add_node(node_to_add)
|
||||
onnx_model.prune_graph()
|
||||
onnx_model.save_model_to_file(optimized_onnx_path)
|
||||
|
||||
|
||||
def diff_outputs(torch_outputs, ort_outputs, index, relative=False):
|
||||
""" Returns the maximum difference between PyTorch and OnnxRuntime outputs.
|
||||
"""
|
||||
expected_outputs = torch_outputs[index].cpu().numpy()
|
||||
diff = numpy.abs(expected_outputs - ort_outputs[index])
|
||||
if relative:
|
||||
return numpy.amax(diff / (numpy.abs(expected_outputs) + 1e-6))
|
||||
else:
|
||||
return numpy.amax(diff)
|
||||
|
||||
|
||||
def compare_outputs(torch_outputs, ort_outputs, rtol=1e-03, atol=1e-03, verbose=False):
|
||||
""" Returns True if torch and ORT outputs are close for given thresholds, and False otherwise.
|
||||
"""
|
||||
is_all_close = True
|
||||
max_abs_diff = []
|
||||
for i in range(len(ort_outputs)):
|
||||
is_close = numpy.allclose(ort_outputs[i], torch_outputs[i].cpu().numpy(), rtol=rtol, atol=atol)
|
||||
if not is_close:
|
||||
is_all_close = False
|
||||
if verbose:
|
||||
print(f'output {i} ({ort_outputs[i].name}) are close: {is_close}')
|
||||
max_abs_diff.append(diff_outputs(torch_outputs, ort_outputs, i))
|
||||
|
||||
if (not is_all_close) or (verbose and max(max_abs_diff) > 0):
|
||||
messages = ["max_abs_diff per output:"]
|
||||
output_names = ["attn_output", "present"] + DEBUG_OUTPUTS
|
||||
for i, diff in enumerate(max_abs_diff):
|
||||
messages.append(f"{output_names[i]}={diff},")
|
||||
print(" ".join(messages))
|
||||
|
||||
return is_all_close, max(max_abs_diff)
|
||||
|
||||
|
||||
def create_ort_session(onnx_model_path, use_gpu=True):
|
||||
from onnxruntime import SessionOptions, InferenceSession, GraphOptimizationLevel, __version__ as onnxruntime_version
|
||||
sess_options = SessionOptions()
|
||||
sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
sess_options.intra_op_num_threads = 2
|
||||
sess_options.log_severity_level = 2
|
||||
execution_providers = ['CPUExecutionProvider'
|
||||
] if not use_gpu else ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
return InferenceSession(onnx_model_path, sess_options, providers=execution_providers)
|
||||
|
||||
|
||||
def onnxruntime_inference(ort_session, input_hidden_states, attention_mask, past):
|
||||
ort_inputs = {
|
||||
'past': numpy.ascontiguousarray(past.cpu().numpy()),
|
||||
'attention_mask': numpy.ascontiguousarray(attention_mask.cpu().numpy()),
|
||||
'input_hidden_states': numpy.ascontiguousarray(input_hidden_states.cpu().numpy()),
|
||||
}
|
||||
|
||||
ort_outputs = ort_session.run(None, ort_inputs)
|
||||
return ort_outputs
|
||||
|
||||
|
||||
def verify_attention(model,
|
||||
onnx_model_path,
|
||||
batch_size,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
sequence_length,
|
||||
past_sequence_length,
|
||||
float16,
|
||||
device,
|
||||
padding_length,
|
||||
optimized,
|
||||
test_cases=100):
|
||||
print(
|
||||
f"optimized={optimized}, batch_size={batch_size}, hidden_size={hidden_size}, num_attention_heads={num_attention_heads}, sequence_length={sequence_length}, past_sequence_length={past_sequence_length}, float16={float16}, padding_length={padding_length}, device={device}"
|
||||
)
|
||||
passed_cases = 0
|
||||
max_diffs = []
|
||||
|
||||
ort_session = create_ort_session(onnx_model_path, device.type == 'cuda')
|
||||
for i in range(test_cases):
|
||||
input_hidden_states, attention_mask, layer_past = create_inputs(batch_size, hidden_size, num_attention_heads,
|
||||
sequence_length, past_sequence_length, float16,
|
||||
device, padding_length)
|
||||
|
||||
with torch.no_grad():
|
||||
torch_outputs = model(input_hidden_states, layer_past=layer_past, attention_mask=attention_mask)
|
||||
|
||||
ort_outputs = onnxruntime_inference(ort_session, input_hidden_states, attention_mask, layer_past)
|
||||
|
||||
tolerance = 1e-03 if float16 else 1e-05
|
||||
is_all_close, max_diff = compare_outputs(torch_outputs, ort_outputs, rtol=tolerance, atol=tolerance)
|
||||
max_diffs.append(max_diff)
|
||||
if is_all_close:
|
||||
passed_cases += 1
|
||||
|
||||
max_diff = max(max_diffs)
|
||||
diff_count = len([i for i in max_diffs if i > 0])
|
||||
success_flag = "[FAILED]" if passed_cases < test_cases else "[OK]"
|
||||
print(f"{success_flag} Passed_cases={passed_cases}/{test_cases}; Max_diff={max_diff}; Diff_count={diff_count}")
|
||||
return test_cases - passed_cases
|
||||
|
||||
|
||||
def run(batch_size, float16, optimized, hidden_size, num_attention_heads, device, test_cases):
|
||||
test_name = f"batch_size={batch_size}, float16={float16}, optimized={optimized}, hidden_size={hidden_size}, num_attention_heads={num_attention_heads}"
|
||||
print(f"\nTesting ONNX parity: {test_name}")
|
||||
|
||||
debug = (not optimized
|
||||
) # or DEBUG_OUTPUTS==["softmax"] when you add an extra output for softmax result in Attention operator
|
||||
model = MyGPT2Attention(hidden_size=hidden_size, num_attention_heads=num_attention_heads, debug=debug)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
if float16:
|
||||
model.half()
|
||||
|
||||
# Do not re-use onnx file from previous test since weights of model are random.
|
||||
onnx_model_path = './temp/gpt_attention_{}.onnx'.format("fp16" if float16 else "fp32")
|
||||
export_onnx(model, onnx_model_path, float16, hidden_size, num_attention_heads, debug, device)
|
||||
|
||||
if optimized:
|
||||
optimized_onnx_path = './temp/gpt_attention_opt_{}.onnx'.format("fp16" if float16 else "fp32")
|
||||
optimize_onnx(onnx_model_path, optimized_onnx_path, num_attention_heads, debug)
|
||||
onnx_path = optimized_onnx_path
|
||||
else:
|
||||
onnx_path = onnx_model_path
|
||||
|
||||
# Test Case: No past state
|
||||
sequence_length = 2
|
||||
past_sequence_length = 0
|
||||
padding_length = 0
|
||||
num_failure = 0
|
||||
num_failure += verify_attention(model,
|
||||
onnx_path,
|
||||
batch_size,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
sequence_length,
|
||||
past_sequence_length,
|
||||
float16,
|
||||
device,
|
||||
padding_length,
|
||||
optimized,
|
||||
test_cases)
|
||||
|
||||
# Test Case: with past state and padding last 2 words
|
||||
sequence_length = 3
|
||||
past_sequence_length = 5
|
||||
padding_length = 2
|
||||
num_failure += verify_attention(model,
|
||||
onnx_path,
|
||||
batch_size,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
sequence_length,
|
||||
past_sequence_length,
|
||||
float16,
|
||||
device,
|
||||
padding_length,
|
||||
optimized,
|
||||
test_cases)
|
||||
|
||||
# Test Case: random mask one word
|
||||
sequence_length = 1
|
||||
past_sequence_length = 128
|
||||
padding_length = -1
|
||||
num_failure += verify_attention(model,
|
||||
onnx_path,
|
||||
batch_size,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
sequence_length,
|
||||
past_sequence_length,
|
||||
float16,
|
||||
device,
|
||||
padding_length,
|
||||
optimized,
|
||||
test_cases)
|
||||
|
||||
# clean up onnx file
|
||||
os.remove(onnx_model_path)
|
||||
if optimized:
|
||||
os.remove(onnx_path)
|
||||
|
||||
return num_failure, test_name
|
||||
|
||||
|
||||
class TestGptAttentionHuggingfaceParity(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.optimized = True # Change it to False if you want to test parity of non optimized ONNX
|
||||
self.test_cases = 10 # Number of test cases per test run
|
||||
|
||||
def run_test(self, batch_size, float16, optimized, hidden_size, num_attention_heads, device):
|
||||
if float16 and device.type=='cpu': # CPU does not support FP16
|
||||
return
|
||||
num_failure, test_name = run(batch_size, float16, optimized, hidden_size, num_attention_heads, device, self.test_cases)
|
||||
self.assertTrue(num_failure == 0, test_name)
|
||||
|
||||
def run_small(self, optimized, device):
|
||||
for batch_size in [64]:
|
||||
self.run_test(batch_size, float16=False, optimized=optimized, hidden_size=768, num_attention_heads=12, device=device)
|
||||
self.run_test(batch_size, float16=True, optimized=optimized, hidden_size=768, num_attention_heads=12, device=device)
|
||||
|
||||
def run_large(self, optimized, device):
|
||||
for batch_size in [2]:
|
||||
self.run_test(batch_size, float16=False, optimized=optimized, hidden_size=4096, num_attention_heads=32, device=device)
|
||||
self.run_test(batch_size, float16=True, optimized=optimized, hidden_size=4096, num_attention_heads=32, device=device)
|
||||
|
||||
def test_cpu(self):
|
||||
cpu = torch.device('cpu')
|
||||
self.run_small(self.optimized, cpu)
|
||||
|
||||
def test_cuda(self):
|
||||
if not torch.cuda.is_available():
|
||||
import pytest
|
||||
pytest.skip('test requires GPU and torch+cuda')
|
||||
else:
|
||||
gpu = torch.device('cuda')
|
||||
self.run_small(self.optimized, gpu)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_large_cuda(self):
|
||||
if not torch.cuda.is_available():
|
||||
import pytest
|
||||
pytest.skip('test requires GPU and torch+cuda')
|
||||
else:
|
||||
gpu = torch.device('cuda')
|
||||
self.run_large(self.optimized, gpu)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in a new issue