mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Move zero point inputs of MatmulInteger to CPU memory (#3159)
This commit is contained in:
parent
51a8c82908
commit
3de1fc096d
3 changed files with 27 additions and 29 deletions
|
|
@ -18,6 +18,8 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
|
|||
int8_t,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(2)
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(3)
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int8_t>())
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<int8_t>())
|
||||
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
|
||||
|
|
@ -58,19 +60,19 @@ Status MatMulInteger<int8_t, int8_t>::ComputeInternal(OpKernelContext* ctx) cons
|
|||
int32_t* output_ptr = Y->template MutableData<int32_t>();
|
||||
|
||||
// validate zero points
|
||||
const int8_t* a_offset = nullptr;
|
||||
const int8_t* b_offset = nullptr;
|
||||
int8_t a_offset = 0;
|
||||
int8_t b_offset = 0;
|
||||
if (has_a_zero_point_) {
|
||||
auto a_zero_point = ctx->Input<Tensor>(2);
|
||||
ORT_ENFORCE(IsScalarOr1ElementVector(a_zero_point),
|
||||
"MatmulInteger : input1 zero point must be a scalar or 1D tensor of size 1");
|
||||
a_offset = a_zero_point->template Data<int8_t>();
|
||||
a_offset = *(a_zero_point->template Data<int8_t>());
|
||||
}
|
||||
if (has_b_zero_point_) {
|
||||
auto b_zero_point = ctx->Input<Tensor>(3);
|
||||
ORT_ENFORCE(IsScalarOr1ElementVector(b_zero_point),
|
||||
"MatmulInteger : input2 zero point must be a scalar or 1D tensor of size 1");
|
||||
b_offset = b_zero_point->template Data<int8_t>();
|
||||
b_offset = *(b_zero_point->template Data<int8_t>());
|
||||
}
|
||||
|
||||
// offset output c[i,j] to
|
||||
|
|
@ -81,20 +83,20 @@ Status MatMulInteger<int8_t, int8_t>::ComputeInternal(OpKernelContext* ctx) cons
|
|||
// ReduceColSumOnMatrixB computes the a_offset * (b[0,j] + b[1,j] ... + b[k,j]) part
|
||||
// OffsetOutput computes gets the final result
|
||||
IAllocatorUniquePtr<int32_t> a_row_buf;
|
||||
if (has_b_zero_point_) {
|
||||
if (b_offset != 0) {
|
||||
a_row_buf = GetScratchBuffer<int32_t>(helper.OutputShape().Size() / helper.N());
|
||||
ORT_RETURN_IF_ERROR(ReduceRowSumOnMatrixA(a_ptr, a_row_buf.get(), b_offset, helper));
|
||||
}
|
||||
|
||||
IAllocatorUniquePtr<int32_t> b_col_buf;
|
||||
if (has_a_zero_point_) {
|
||||
if (a_offset != 0) {
|
||||
b_col_buf = GetScratchBuffer<int32_t>(helper.OutputShape().Size() / helper.M());
|
||||
ORT_RETURN_IF_ERROR(ReduceColSumOnMatrixB(b_ptr, b_col_buf.get(), a_offset, helper));
|
||||
}
|
||||
|
||||
int alpha = 1;
|
||||
int beta = 0;
|
||||
if (has_a_zero_point_ || has_b_zero_point_) {
|
||||
if (a_offset != 0 || b_offset != 0) {
|
||||
OffsetOutput(a_row_buf.get(),
|
||||
b_col_buf.get(),
|
||||
output_ptr,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ namespace onnxruntime {
|
|||
namespace cuda {
|
||||
|
||||
template <int TPB>
|
||||
__global__ void ReduceRowSumOnMatrixAKernel(const int8_t* matrix, int32_t* row_sum, const int8_t* offset, int32_t K) {
|
||||
__global__ void ReduceRowSumOnMatrixAKernel(const int8_t* matrix, int32_t* row_sum, const int8_t offset, int32_t K) {
|
||||
int32_t thread_data = 0;
|
||||
const int8_t* row_ptr = matrix + blockIdx.x * K;
|
||||
for (int i = threadIdx.x; i < K; i += TPB) {
|
||||
|
|
@ -22,11 +22,11 @@ __global__ void ReduceRowSumOnMatrixAKernel(const int8_t* matrix, int32_t* row_s
|
|||
int32_t sum = BlockReduce(temp_storage).Sum(thread_data);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
row_sum[blockIdx.x] = (*offset) * sum;
|
||||
row_sum[blockIdx.x] = offset * sum;
|
||||
}
|
||||
}
|
||||
|
||||
Status ReduceRowSumOnMatrixA(const int8_t* matrix, int32_t* row_sum, const int8_t* offset, const MatMulComputeHelper& helper) {
|
||||
Status ReduceRowSumOnMatrixA(const int8_t* matrix, int32_t* row_sum, const int8_t offset, const MatMulComputeHelper& helper) {
|
||||
for (size_t batch = 0; batch < helper.OutputOffsets().size(); batch++) {
|
||||
ReduceRowSumOnMatrixAKernel<static_cast<int>(GridDim::maxThreadsPerBlock)>
|
||||
<<<static_cast<int>(helper.M()), GridDim::maxThreadsPerBlock, 0>>>(matrix + helper.LeftOffsets()[batch],
|
||||
|
|
@ -35,11 +35,11 @@ Status ReduceRowSumOnMatrixA(const int8_t* matrix, int32_t* row_sum, const int8_
|
|||
static_cast<int>(helper.K()));
|
||||
}
|
||||
|
||||
return CUDA_CALL( cudaPeekAtLastError() ) ? Status::OK() : Status( common::ONNXRUNTIME, common::FAIL );
|
||||
return CUDA_CALL(cudaPeekAtLastError()) ? Status::OK() : Status(common::ONNXRUNTIME, common::FAIL);
|
||||
}
|
||||
|
||||
template <int TPB>
|
||||
__global__ void ReduceColSumOnMatrixBKernel(const int8_t* matrix, int32_t* col_sum, const int8_t* offset, int32_t row, int32_t col) {
|
||||
__global__ void ReduceColSumOnMatrixBKernel(const int8_t* matrix, int32_t* col_sum, const int8_t offset, int32_t row, int32_t col) {
|
||||
int32_t thread_data = 0;
|
||||
const int8_t* col_ptr = matrix + blockIdx.x;
|
||||
for (int i = threadIdx.x; i < row; i += TPB) {
|
||||
|
|
@ -51,11 +51,11 @@ __global__ void ReduceColSumOnMatrixBKernel(const int8_t* matrix, int32_t* col_s
|
|||
int32_t sum = BlockReduce(temp_storage).Sum(thread_data);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
col_sum[blockIdx.x] = (*offset) * sum;
|
||||
col_sum[blockIdx.x] = offset * sum;
|
||||
}
|
||||
}
|
||||
|
||||
Status ReduceColSumOnMatrixB(const int8_t* matrix, int32_t* col_sum, const int8_t* offset, const MatMulComputeHelper& helper) {
|
||||
Status ReduceColSumOnMatrixB(const int8_t* matrix, int32_t* col_sum, const int8_t offset, const MatMulComputeHelper& helper) {
|
||||
for (size_t batch = 0; batch < helper.OutputOffsets().size(); batch++) {
|
||||
ReduceColSumOnMatrixBKernel<static_cast<int>(GridDim::maxThreadsPerBlock)>
|
||||
<<<static_cast<int>(helper.N()), GridDim::maxThreadsPerBlock, 0>>>(matrix + helper.RightOffsets()[batch],
|
||||
|
|
@ -65,18 +65,16 @@ Status ReduceColSumOnMatrixB(const int8_t* matrix, int32_t* col_sum, const int8_
|
|||
static_cast<int32_t>(helper.N()));
|
||||
}
|
||||
|
||||
return CUDA_CALL( cudaPeekAtLastError() ) ? Status::OK() : Status( common::ONNXRUNTIME, common::FAIL );
|
||||
return CUDA_CALL(cudaPeekAtLastError()) ? Status::OK() : Status(common::ONNXRUNTIME, common::FAIL);
|
||||
}
|
||||
|
||||
__global__ void ComputeOffsetOfMatrixAB(const int32_t* row_sum,
|
||||
const int32_t* col_sum,
|
||||
int32_t* output,
|
||||
const int8_t* a_offset,
|
||||
const int8_t* b_offset,
|
||||
int32_t K,
|
||||
int32_t K_A_B,
|
||||
int32_t N) {
|
||||
for (int32_t i = threadIdx.x; i < N; i += blockDim.x) {
|
||||
*(output + blockIdx.x * N + i) = K * (*a_offset) * (*b_offset) - row_sum[blockIdx.x] - col_sum[i];
|
||||
*(output + blockIdx.x * N + i) = K_A_B - row_sum[blockIdx.x] - col_sum[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -99,8 +97,8 @@ __global__ void ComputeOffsetOfMatrixB(const int32_t* row_sum,
|
|||
Status OffsetOutput(const int32_t* row_sum,
|
||||
const int32_t* col_sum,
|
||||
int32_t* output,
|
||||
const int8_t* a_offset,
|
||||
const int8_t* b_offset,
|
||||
const int8_t a_offset,
|
||||
const int8_t b_offset,
|
||||
const MatMulComputeHelper& helper) {
|
||||
if (a_offset && b_offset) {
|
||||
for (size_t batch = 0; batch < helper.OutputOffsets().size(); batch++) {
|
||||
|
|
@ -108,9 +106,7 @@ Status OffsetOutput(const int32_t* row_sum,
|
|||
row_sum + batch * helper.M(),
|
||||
col_sum + batch * helper.N(),
|
||||
output + helper.OutputOffsets()[batch],
|
||||
a_offset,
|
||||
b_offset,
|
||||
static_cast<int32_t>(helper.K()),
|
||||
static_cast<int32_t>(helper.K()) * a_offset * b_offset,
|
||||
static_cast<int32_t>(helper.N()));
|
||||
}
|
||||
} else if (a_offset) {
|
||||
|
|
@ -129,7 +125,7 @@ Status OffsetOutput(const int32_t* row_sum,
|
|||
}
|
||||
}
|
||||
|
||||
return CUDA_CALL( cudaPeekAtLastError() ) ? Status::OK() : Status( common::ONNXRUNTIME, common::FAIL );
|
||||
return CUDA_CALL(cudaPeekAtLastError()) ? Status::OK() : Status(common::ONNXRUNTIME, common::FAIL);
|
||||
}
|
||||
|
||||
__global__ void PadMatrixInLeadingDimensionKernel(const int8_t* src, int8_t* dst, int col_src, int col_dst) {
|
||||
|
|
|
|||
|
|
@ -11,13 +11,13 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
Status ReduceRowSumOnMatrixA(const int8_t* matrix, int32_t* row_sum, const int8_t* offset, const MatMulComputeHelper& helper);
|
||||
Status ReduceColSumOnMatrixB(const int8_t* matrix, int32_t* col_sum, const int8_t* offset, const MatMulComputeHelper& helper);
|
||||
Status ReduceRowSumOnMatrixA(const int8_t* matrix, int32_t* row_sum, const int8_t offset, const MatMulComputeHelper& helper);
|
||||
Status ReduceColSumOnMatrixB(const int8_t* matrix, int32_t* col_sum, const int8_t offset, const MatMulComputeHelper& helper);
|
||||
Status OffsetOutput(const int32_t* row_sum,
|
||||
const int32_t* col_sum,
|
||||
int32_t* output,
|
||||
const int8_t* a_offset,
|
||||
const int8_t* b_offset,
|
||||
const int8_t a_offset,
|
||||
const int8_t b_offset,
|
||||
const MatMulComputeHelper& helper);
|
||||
|
||||
Status PadMatrixInLeadingDimension(const int8_t* src, int8_t* dst, int64_t row, int64_t col, int64_t pad_size);
|
||||
|
|
|
|||
Loading…
Reference in a new issue