From 76bc0e479c9988152598bcffde96b7f77fbcbe1f Mon Sep 17 00:00:00 2001 From: Suffian Khan Date: Fri, 29 Jan 2021 16:12:34 -0500 Subject: [PATCH] Enable dense sequence optimized version of Pytorch exported BERT-L on AMD GPU (#6504) * Permit dense seq optimization on BERT-L pytorch export by enabling ReduceSumTraining, Equal, and NonZero on AMD * enable Equal tests * enable fast_matrix_reduction test case --- .../core/providers/cuda/tensor/nonzero_impl.cu | 10 +++++----- .../providers/rocm/rocm_execution_provider.cc | 16 ++++++++-------- .../cpu/reduction/reduction_ops_test.cc | 17 +++++++++++++++++ .../training_ops/rocm/rocm_training_kernels.cc | 6 +++--- tools/ci_build/amd_hipify.py | 4 ---- .../ci_build/github/pai/pai-excluded-tests.txt | 8 ++++---- 6 files changed, 37 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu b/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu index a3ba37402f..1ac1ae79e9 100644 --- a/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu @@ -34,9 +34,9 @@ __global__ void NonZeroCountEachBlockKernel(const InputT* x, int64_t x_size, int __shared__ typename BlockReduceT::TempStorage temp_storage; int64_t index = blockIdx.x * blockDim.x + threadIdx.x; - const cub::CastOp cast_to_bool; + // const cub::CastOp cast_to_bool; not supported on amd hipcub int nz = 0; - if (index < x_size && cast_to_bool(x[index])) ++nz; + if (index < x_size && bool(x[index])) ++nz; int count = BlockReduceT(temp_storage).Sum(nz); if (threadIdx.x == 0) { @@ -52,15 +52,15 @@ __global__ void NonZeroOutputPositionsKernel( __shared__ typename BlockScanT::TempStorage temp_storage; int64_t index = blockIdx.x * blockDim.x + threadIdx.x; - const cub::CastOp cast_to_bool; + // const cub::CastOp cast_to_bool; not supported on amd hipcub int nz = 0; - if (index < x_size && cast_to_bool(x[index])) ++nz; + if (index < x_size && bool(x[index])) ++nz; int pos_in_block = 0; BlockScanT(temp_storage).InclusiveSum(nz, pos_in_block); int result_position = ((blockIdx.x == 0) ? 0 : prefix_counts[blockIdx.x - 1]) + pos_in_block - nz; - if (index < x_size && cast_to_bool(x[index])) { + if (index < x_size && bool(x[index])) { int remain = (int)index, dim = 0; for (int axis = 0, rp = result_position; axis < x_rank; ++axis, rp += nonzero_elements) { x_strides[axis].divmod(remain, dim, remain); diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index d5e0ee42b2..6df7546236 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1295,11 +1295,11 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, @@ -1427,9 +1427,9 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, diff --git a/orttraining/orttraining/test/training_ops/cpu/reduction/reduction_ops_test.cc b/orttraining/orttraining/test/training_ops/cpu/reduction/reduction_ops_test.cc index 5cca525229..18218b76f2 100644 --- a/orttraining/orttraining/test/training_ops/cpu/reduction/reduction_ops_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/reduction/reduction_ops_test.cc @@ -246,6 +246,23 @@ TEST(ReductionOpTest, ReduceSumTraining_int32) { test.Run(); } +TEST(ReductionOpTest, ReduceSumTraining_fast_matrix_reduction) { + OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 4}, + {1.0f, 2.0f, + 3.0f, 4.0f, + + 5.0f, 6.0f, + 7.0f, 8.0f, + + 9.0f, 10.0f, + 11.0f, 12.0f}); + test.AddInput("axes", {2}, {0, 1}, true /*is_initializer*/); + test.AddOutput("reduced", {1, 1}, {78.0f}); + test.Run(); +} + TEST(ReductionOpTest, ReduceSumTraining_default_axes_keepdims) { OpTester test("ReduceSumTraining", 1, onnxruntime::kMSDomain); test.AddAttribute("keepdims", (int64_t)1); diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc index 85d887ec81..d76b62ded5 100644 --- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc +++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc @@ -138,10 +138,10 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // Adam diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index bc638fddec..6b498b025c 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -166,10 +166,6 @@ provider_excluded_files = [ 'tensor/gather_elements_impl.cu', 'tensor/gather_elements_impl.h', 'tensor/gather_nd_impl.cu', - 'tensor/nonzero_impl.cu', - 'tensor/nonzero_impl.h', - 'tensor/nonzero_op.cc', - 'tensor/nonzero_op.h', 'tensor/pad.cc', 'tensor/pad.h', 'tensor/pad_impl.cu', diff --git a/tools/ci_build/github/pai/pai-excluded-tests.txt b/tools/ci_build/github/pai/pai-excluded-tests.txt index 6ed3d24ce7..072e82fe56 100644 --- a/tools/ci_build/github/pai/pai-excluded-tests.txt +++ b/tools/ci_build/github/pai/pai-excluded-tests.txt @@ -102,6 +102,10 @@ ReductionOpTest.ReduceSumSquare_do_not_keepdims ReductionOpTest.ReduceSumSquare_do_not_keepdims_2 ReductionOpTest.ReduceSumSquare_keepdims ReductionOpTest.ReduceSumSquare0DTensor +ReductionOpTest.ReduceSumTraining_default_axes_keepdims +ReductionOpTest.ReduceSumTraining_axes_not_initializer +ReductionOpTest.ReduceSumTraining_do_not_keepdims +ReductionOpTest.ReduceSumTraining_neg_axis ReductionOpTest.ReduceProd_default_axes_keepdims ReductionOpTest.ReduceProd_default_axes_do_not_keep_dims ReductionOpTest.ReduceProd_do_not_keepdims @@ -147,10 +151,6 @@ MathOpTest.Less_multidiretional_broadcastBA MathOpTest.Greater_broadcastBA MathOpTest.Greater_multidiretional_broadcastAB MathOpTest.Greater_multidiretional_broadcastBA -MathOpTest.Equal_broadcastBA -MathOpTest.Equal_multidiretional_broadcastAB -MathOpTest.Equal_multidiretional_broadcastBA -MathOpTest.Equal_multidiretional_broadcastAB_bool MathOpTest.Pow_int64_float MathOpTest.Pow_int32_float MathOpTest.Pow_int64_double