Enable dense sequence optimized version of Pytorch exported BERT-L on AMD GPU (#6504)

* Permit dense seq optimization on BERT-L pytorch export by enabling ReduceSumTraining, Equal, and NonZero on AMD

* enable Equal tests

* enable fast_matrix_reduction test case
This commit is contained in:
Suffian Khan 2021-01-29 16:12:34 -05:00 committed by GitHub
parent 8c6d76a4c0
commit 76bc0e479c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 37 additions and 24 deletions

View file

@ -34,9 +34,9 @@ __global__ void NonZeroCountEachBlockKernel(const InputT* x, int64_t x_size, int
__shared__ typename BlockReduceT::TempStorage temp_storage;
int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
const cub::CastOp<bool> cast_to_bool;
// const cub::CastOp<bool> cast_to_bool; not supported on amd hipcub
int nz = 0;
if (index < x_size && cast_to_bool(x[index])) ++nz;
if (index < x_size && bool(x[index])) ++nz;
int count = BlockReduceT(temp_storage).Sum(nz);
if (threadIdx.x == 0) {
@ -52,15 +52,15 @@ __global__ void NonZeroOutputPositionsKernel(
__shared__ typename BlockScanT::TempStorage temp_storage;
int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
const cub::CastOp<bool> cast_to_bool;
// const cub::CastOp<bool> cast_to_bool; not supported on amd hipcub
int nz = 0;
if (index < x_size && cast_to_bool(x[index])) ++nz;
if (index < x_size && bool(x[index])) ++nz;
int pos_in_block = 0;
BlockScanT(temp_storage).InclusiveSum(nz, pos_in_block);
int result_position = ((blockIdx.x == 0) ? 0 : prefix_counts[blockIdx.x - 1]) + pos_in_block - nz;
if (index < x_size && cast_to_bool(x[index])) {
if (index < x_size && bool(x[index])) {
int remain = (int)index, dim = 0;
for (int axis = 0, rp = result_position; axis < x_rank; ++axis, rp += nonzero_elements) {
x_strides[axis].divmod(remain, dim, remain);

View file

@ -1295,11 +1295,11 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, int32_t, Where)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, int64_t, Where)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, uint8_t, Where)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, bool, NonZero)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint8_t, NonZero)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int32_t, NonZero)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int64_t, NonZero)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, NonZero)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, bool, NonZero)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint8_t, NonZero)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int32_t, NonZero)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int64_t, NonZero)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, NonZero)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 9, TopK)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 8, Scan)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 10, Scan)>,
@ -1427,9 +1427,9 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, Pad)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, Pad)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, Pad)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, bool, Equal)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, Equal)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int64_t, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, bool, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int64_t, Equal)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, Round)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, Round)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, Round)>,

View file

@ -246,6 +246,23 @@ TEST(ReductionOpTest, ReduceSumTraining_int32) {
test.Run();
}
TEST(ReductionOpTest, ReduceSumTraining_fast_matrix_reduction) {
OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain);
test.AddAttribute("keepdims", (int64_t)1);
test.AddInput<float>("data", {3, 4},
{1.0f, 2.0f,
3.0f, 4.0f,
5.0f, 6.0f,
7.0f, 8.0f,
9.0f, 10.0f,
11.0f, 12.0f});
test.AddInput<int64_t>("axes", {2}, {0, 1}, true /*is_initializer*/);
test.AddOutput<float>("reduced", {1, 1}, {78.0f});
test.Run();
}
TEST(ReductionOpTest, ReduceSumTraining_default_axes_keepdims) {
OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain);
test.AddAttribute("keepdims", (int64_t)1);

View file

@ -138,10 +138,10 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, View)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Group)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SGDOptimizer)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ReduceSumTraining)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ReduceSumTraining)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ReduceSumTraining)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ReduceSumTraining)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int32_t, ReduceSumTraining)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ReduceSumTraining)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ReduceSumTraining)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SplitTraining)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ConcatTraining)>,
// Adam

View file

@ -166,10 +166,6 @@ provider_excluded_files = [
'tensor/gather_elements_impl.cu',
'tensor/gather_elements_impl.h',
'tensor/gather_nd_impl.cu',
'tensor/nonzero_impl.cu',
'tensor/nonzero_impl.h',
'tensor/nonzero_op.cc',
'tensor/nonzero_op.h',
'tensor/pad.cc',
'tensor/pad.h',
'tensor/pad_impl.cu',

View file

@ -102,6 +102,10 @@ ReductionOpTest.ReduceSumSquare_do_not_keepdims
ReductionOpTest.ReduceSumSquare_do_not_keepdims_2
ReductionOpTest.ReduceSumSquare_keepdims
ReductionOpTest.ReduceSumSquare0DTensor
ReductionOpTest.ReduceSumTraining_default_axes_keepdims
ReductionOpTest.ReduceSumTraining_axes_not_initializer
ReductionOpTest.ReduceSumTraining_do_not_keepdims
ReductionOpTest.ReduceSumTraining_neg_axis
ReductionOpTest.ReduceProd_default_axes_keepdims
ReductionOpTest.ReduceProd_default_axes_do_not_keep_dims
ReductionOpTest.ReduceProd_do_not_keepdims
@ -147,10 +151,6 @@ MathOpTest.Less_multidiretional_broadcastBA
MathOpTest.Greater_broadcastBA
MathOpTest.Greater_multidiretional_broadcastAB
MathOpTest.Greater_multidiretional_broadcastBA
MathOpTest.Equal_broadcastBA
MathOpTest.Equal_multidiretional_broadcastAB
MathOpTest.Equal_multidiretional_broadcastBA
MathOpTest.Equal_multidiretional_broadcastAB_bool
MathOpTest.Pow_int64_float
MathOpTest.Pow_int32_float
MathOpTest.Pow_int64_double