mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
Enable dense sequence optimized version of PyTorch-exported BERT-L on AMD GPU (#6504)
* Permit dense sequence optimization on the BERT-L PyTorch export by enabling ReduceSumTraining, Equal, and NonZero on AMD
* Enable Equal tests
* Enable fast_matrix_reduction test case
This commit is contained in:
parent
8c6d76a4c0
commit
76bc0e479c
6 changed files with 37 additions and 24 deletions
|
|
@ -34,9 +34,9 @@ __global__ void NonZeroCountEachBlockKernel(const InputT* x, int64_t x_size, int
|
|||
__shared__ typename BlockReduceT::TempStorage temp_storage;
|
||||
|
||||
int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const cub::CastOp<bool> cast_to_bool;
|
||||
// const cub::CastOp<bool> cast_to_bool; not supported on amd hipcub
|
||||
int nz = 0;
|
||||
if (index < x_size && cast_to_bool(x[index])) ++nz;
|
||||
if (index < x_size && bool(x[index])) ++nz;
|
||||
int count = BlockReduceT(temp_storage).Sum(nz);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
|
|
@ -52,15 +52,15 @@ __global__ void NonZeroOutputPositionsKernel(
|
|||
__shared__ typename BlockScanT::TempStorage temp_storage;
|
||||
|
||||
int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const cub::CastOp<bool> cast_to_bool;
|
||||
// const cub::CastOp<bool> cast_to_bool; not supported on amd hipcub
|
||||
int nz = 0;
|
||||
if (index < x_size && cast_to_bool(x[index])) ++nz;
|
||||
if (index < x_size && bool(x[index])) ++nz;
|
||||
int pos_in_block = 0;
|
||||
BlockScanT(temp_storage).InclusiveSum(nz, pos_in_block);
|
||||
|
||||
int result_position = ((blockIdx.x == 0) ? 0 : prefix_counts[blockIdx.x - 1]) + pos_in_block - nz;
|
||||
|
||||
if (index < x_size && cast_to_bool(x[index])) {
|
||||
if (index < x_size && bool(x[index])) {
|
||||
int remain = (int)index, dim = 0;
|
||||
for (int axis = 0, rp = result_position; axis < x_rank; ++axis, rp += nonzero_elements) {
|
||||
x_strides[axis].divmod(remain, dim, remain);
|
||||
|
|
|
|||
|
|
@ -1295,11 +1295,11 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, int32_t, Where)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, int64_t, Where)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, uint8_t, Where)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, bool, NonZero)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint8_t, NonZero)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int32_t, NonZero)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int64_t, NonZero)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, bool, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint8_t, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int32_t, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int64_t, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, NonZero)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 9, TopK)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 8, Scan)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 10, Scan)>,
|
||||
|
|
@ -1427,9 +1427,9 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, Pad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, Pad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, Pad)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, bool, Equal)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, Equal)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int64_t, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, bool, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int64_t, Equal)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, Round)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, Round)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, Round)>,
|
||||
|
|
|
|||
|
|
@ -246,6 +246,23 @@ TEST(ReductionOpTest, ReduceSumTraining_int32) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
// Reduces a 3x4 float tensor over both of its axes (0 and 1) with keepdims=1
// and expects a single-element 1x1 result equal to the sum of all twelve
// inputs (1 + 2 + ... + 12 = 78).
// NOTE(review): the test name suggests this shape/axes combination is meant to
// hit a dedicated "fast matrix reduction" code path in the provider — confirm
// against the reduction kernel implementation.
TEST(ReductionOpTest, ReduceSumTraining_fast_matrix_reduction) {
|
||||
OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain);
|
||||
test.AddAttribute("keepdims", (int64_t)1);  // reduced axes stay as size-1 dims (output is {1, 1})
|
||||
test.AddInput<float>("data", {3, 4},  // 12 values, declared as a 3x4 tensor
|
||||
{1.0f, 2.0f,
|
||||
3.0f, 4.0f,
|
||||
|
||||
5.0f, 6.0f,
|
||||
7.0f, 8.0f,
|
||||
|
||||
9.0f, 10.0f,
|
||||
11.0f, 12.0f});
|
||||
test.AddInput<int64_t>("axes", {2}, {0, 1}, true /*is_initializer*/);  // reduce over both axes
|
||||
test.AddOutput<float>("reduced", {1, 1}, {78.0f});  // 78 = sum of 1..12
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ReductionOpTest, ReduceSumTraining_default_axes_keepdims) {
|
||||
OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain);
|
||||
test.AddAttribute("keepdims", (int64_t)1);
|
||||
|
|
|
|||
|
|
@ -138,10 +138,10 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, View)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Group)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SGDOptimizer)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ReduceSumTraining)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ReduceSumTraining)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ReduceSumTraining)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ReduceSumTraining)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int32_t, ReduceSumTraining)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ReduceSumTraining)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ReduceSumTraining)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SplitTraining)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ConcatTraining)>,
|
||||
// Adam
|
||||
|
|
|
|||
|
|
@ -166,10 +166,6 @@ provider_excluded_files = [
|
|||
'tensor/gather_elements_impl.cu',
|
||||
'tensor/gather_elements_impl.h',
|
||||
'tensor/gather_nd_impl.cu',
|
||||
'tensor/nonzero_impl.cu',
|
||||
'tensor/nonzero_impl.h',
|
||||
'tensor/nonzero_op.cc',
|
||||
'tensor/nonzero_op.h',
|
||||
'tensor/pad.cc',
|
||||
'tensor/pad.h',
|
||||
'tensor/pad_impl.cu',
|
||||
|
|
|
|||
|
|
@ -102,6 +102,10 @@ ReductionOpTest.ReduceSumSquare_do_not_keepdims
|
|||
ReductionOpTest.ReduceSumSquare_do_not_keepdims_2
|
||||
ReductionOpTest.ReduceSumSquare_keepdims
|
||||
ReductionOpTest.ReduceSumSquare0DTensor
|
||||
ReductionOpTest.ReduceSumTraining_default_axes_keepdims
|
||||
ReductionOpTest.ReduceSumTraining_axes_not_initializer
|
||||
ReductionOpTest.ReduceSumTraining_do_not_keepdims
|
||||
ReductionOpTest.ReduceSumTraining_neg_axis
|
||||
ReductionOpTest.ReduceProd_default_axes_keepdims
|
||||
ReductionOpTest.ReduceProd_default_axes_do_not_keep_dims
|
||||
ReductionOpTest.ReduceProd_do_not_keepdims
|
||||
|
|
@ -147,10 +151,6 @@ MathOpTest.Less_multidiretional_broadcastBA
|
|||
MathOpTest.Greater_broadcastBA
|
||||
MathOpTest.Greater_multidiretional_broadcastAB
|
||||
MathOpTest.Greater_multidiretional_broadcastBA
|
||||
MathOpTest.Equal_broadcastBA
|
||||
MathOpTest.Equal_multidiretional_broadcastAB
|
||||
MathOpTest.Equal_multidiretional_broadcastBA
|
||||
MathOpTest.Equal_multidiretional_broadcastAB_bool
|
||||
MathOpTest.Pow_int64_float
|
||||
MathOpTest.Pow_int32_float
|
||||
MathOpTest.Pow_int64_double
|
||||
|
|
|
|||
Loading…
Reference in a new issue