mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
fix static analysis: integer_gemm and attention_quantization (#13004)
This commit is contained in:
parent
454f77cd94
commit
dd39f0293d
2 changed files with 13 additions and 11 deletions
|
|
@ -4,6 +4,7 @@
|
|||
#include "attention_quantization.h"
|
||||
#include "attention_quantization_impl.cuh"
|
||||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
#include "core/providers/cuda/shared_inc/integer_gemm.h"
|
||||
|
|
@ -112,18 +113,18 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
|||
past_tensor));
|
||||
|
||||
const auto& shape = input->Shape();
|
||||
int batch_size = static_cast<int>(shape[0]);
|
||||
int sequence_length = static_cast<int>(shape[1]);
|
||||
int input_hidden_size = static_cast<int>(shape[2]);
|
||||
int batch_size = SafeInt<int>(shape[0]);
|
||||
int sequence_length = SafeInt<int>(shape[1]);
|
||||
int input_hidden_size = SafeInt<int>(shape[2]);
|
||||
|
||||
const auto& bias_shape = bias->Shape();
|
||||
const int hidden_size = static_cast<int>(bias_shape.GetDims()[0]) / 3;
|
||||
const int hidden_size = SafeInt<int>(bias_shape.GetDims()[0]) / 3;
|
||||
const int head_size = hidden_size / num_heads_;
|
||||
|
||||
TensorShapeVector output_shape(3);
|
||||
output_shape[0] = shape[0];
|
||||
output_shape[1] = shape[1];
|
||||
output_shape[2] = static_cast<int64_t>(hidden_size);
|
||||
output_shape[2] = SafeInt<int64_t>(hidden_size);
|
||||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
cublasHandle_t cublas = CublasHandle();
|
||||
|
|
@ -133,8 +134,8 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
|||
int m = batch_size * sequence_length;
|
||||
int n = 3 * hidden_size;
|
||||
int k = input_hidden_size;
|
||||
auto gemm_buffer = GetScratchBuffer<T>(batch_size * sequence_length * 3 * hidden_size * element_size);
|
||||
auto gemm_buffer_quantized = GetScratchBuffer<int32_t>(batch_size * sequence_length * 3 * hidden_size);
|
||||
auto gemm_buffer = GetScratchBuffer<T>(SafeInt<size_t>(batch_size) * sequence_length * 3 * hidden_size * element_size);
|
||||
auto gemm_buffer_quantized = GetScratchBuffer<int32_t>(SafeInt<size_t>(batch_size) * sequence_length * 3 * hidden_size);
|
||||
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
||||
|
|
|
|||
|
|
@ -3,13 +3,14 @@
|
|||
|
||||
#include "core/providers/cuda/shared_inc/integer_gemm.h"
|
||||
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/shared_inc/cuda_call.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
inline int roundoff(int v, int d) {
|
||||
constexpr int roundoff(int v, int d) {
|
||||
return (v + d - 1) / d * d;
|
||||
}
|
||||
|
||||
|
|
@ -27,12 +28,12 @@ Status GemmInt8(int m, int n, int k,
|
|||
// 1. the leading dimensions are multiples of 4
|
||||
// 2. A and B are 32-bit aligned
|
||||
|
||||
const int mask = 0x1F;
|
||||
constexpr int mask = 0x1F;
|
||||
int lda_aligned = lda;
|
||||
IAllocatorUniquePtr<int8_t> a_padded;
|
||||
if ((mask & lda_aligned) != 0) {
|
||||
lda_aligned = roundoff(lda, 32);
|
||||
a_padded = cuda_kernel->GetScratchBuffer<int8_t>(m * lda_aligned);
|
||||
a_padded = cuda_kernel->GetScratchBuffer<int8_t>(SafeInt<size_t>(m) * lda_aligned);
|
||||
cudaMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, cudaMemcpyDeviceToDevice, stream);
|
||||
}
|
||||
|
||||
|
|
@ -40,7 +41,7 @@ Status GemmInt8(int m, int n, int k,
|
|||
IAllocatorUniquePtr<int8_t> b_padded;
|
||||
if ((mask & ldb_aligned) != 0) {
|
||||
ldb_aligned = roundoff(ldb, 32);
|
||||
b_padded = cuda_kernel->GetScratchBuffer<int8_t>(k * ldb_aligned);
|
||||
b_padded = cuda_kernel->GetScratchBuffer<int8_t>(SafeInt<size_t>(k) * ldb_aligned);
|
||||
cudaMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, cudaMemcpyDeviceToDevice, stream);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue