diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_functions.cc b/onnxruntime/core/providers/cuda/reduction/reduction_functions.cc index 955df6d9a1..09c38c2def 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_functions.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_functions.cc @@ -105,8 +105,42 @@ ApplicableMatrixReduction get_applicable_matrix_reduction( return ApplicableMatrixReduction::None; } - const auto rank = gsl::narrow(dims.size()); - const auto min_and_max_axes = GetMinAndMaxContiguousAxes(rank, dims, original_axes); + + // Remove all dims with value 1. This can help to optimize case like: + // dims=[2,3,1,4,1,5] and axes=[0,2,4], which is same as dims=[2,3,4,5] and axes=[0]. + std::vector new_dims; + std::vector new_axes; + const auto original_rank = gsl::narrow(dims.size()); + std::set original_axes_set; + for (const auto axis : original_axes) { + original_axes_set.insert(HandleNegativeAxis(axis, original_rank)); + } + + int64_t new_axis = 0; + for (size_t i = 0; i < dims.size(); i++) { + if (dims[i] != 1) { + new_dims.emplace_back(dims[i]); + if (original_axes_set.find(gsl::narrow(i)) != original_axes_set.end()) { + new_axes.emplace_back(new_axis); + } + new_axis++; + } + } + + // Empty axes means reduce all dimensions, which has different meaning, + // so add a new dim to the end if all original axes are on dims with value 1. + if (!original_axes.empty() && new_axes.empty()) { + new_dims.emplace_back(1); + new_axes.emplace_back(new_axis); + } + + // If all dims are value 1, make sure it's not empty by adding a new dim. + if (!dims.empty() && new_dims.empty()) { + new_dims.emplace_back(1); + } + + const auto rank = gsl::narrow(new_dims.size()); + const auto min_and_max_axes = GetMinAndMaxContiguousAxes(rank, new_dims, new_axes); if (!min_and_max_axes.has_value()) { return ApplicableMatrixReduction::None; } @@ -127,7 +161,7 @@ ApplicableMatrixReduction get_applicable_matrix_reduction( // the axis index right after the last flattened into matrix rows const int64_t m_end_axis = axes_from_beginning ? max_axis + 1 : min_axis; - const TensorShape& shape = TensorShape::ReinterpretBaseType(dims); + const TensorShape& shape = TensorShape::ReinterpretBaseType(new_dims); const auto m = shape.SizeToDimension(m_end_axis); const auto n = shape.SizeFromDimension(m_end_axis); diff --git a/orttraining/orttraining/test/training_ops/cuda/reduce_sum_test.cc b/orttraining/orttraining/test/training_ops/cuda/reduce_sum_test.cc index dd76668d95..fe1b9ebfa0 100644 --- a/orttraining/orttraining/test/training_ops/cuda/reduce_sum_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/reduce_sum_test.cc @@ -19,7 +19,7 @@ static void TestReduceSum(const std::vector& X_dims, double per_sample_tolerance = 2e-4, double relative_per_sample_tolerance = 2e-4) { CompareOpTester test("ReduceSum"); - test.AddAttribute("axes", axes); + if (!axes.empty()) test.AddAttribute("axes", axes); test.AddAttribute("keepdims", int64_t(keepdims)); // create rand inputs @@ -38,66 +38,79 @@ static void TestReduceSum(const std::vector& X_dims, TEST(CudaKernelTest, ReduceSum_Scalar) { std::vector X_dims{1}; - std::vector Y_dims{}; std::vector axes{0}; - bool keepdims = false; - TestReduceSum(X_dims, Y_dims, axes, keepdims); + TestReduceSum(X_dims, {}, axes, false); + TestReduceSum(X_dims, {1}, axes, true); } TEST(CudaKernelTest, ReduceSum_2DtoLastDim) { std::vector X_dims{16, 2}; - std::vector Y_dims{2}; std::vector axes{0}; - bool keepdims = false; - TestReduceSum(X_dims, Y_dims, axes, keepdims); + TestReduceSum(X_dims, {2}, axes, false); + TestReduceSum(X_dims, {1, 2}, axes, true); } TEST(CudaKernelTest, ReduceSum_SmallTensor) { std::vector X_dims{2, 128, 128}; - std::vector Y_dims{128}; std::vector axes{0, 1}; - bool keepdims = false; - TestReduceSum(X_dims, Y_dims, axes, keepdims); + TestReduceSum(X_dims, {128}, axes, false); + TestReduceSum(X_dims, {1, 1, 128}, axes, true); } TEST(CudaKernelTest, ReduceSum_MidTensor) { std::vector X_dims{2, 512, 3072}; - std::vector Y_dims{3072}; std::vector axes{0, 1}; - bool keepdims = false; - TestReduceSum(X_dims, Y_dims, axes, keepdims); + TestReduceSum(X_dims, {3072}, axes, false); + TestReduceSum(X_dims, {1, 1, 3072}, axes, true); } TEST(CudaKernelTest, ReduceSum_LargeTensor) { std::vector X_dims{4, 512, 30528}; - std::vector Y_dims{30528}; std::vector axes{0, 1}; - bool keepdims = false; - TestReduceSum(X_dims, Y_dims, axes, keepdims); + TestReduceSum(X_dims, {30528}, axes, false); + TestReduceSum(X_dims, {1, 1, 30528}, axes, true); } TEST(CudaKernelTest, ReduceSum_SmallTensorTrailingAxes) { std::vector X_dims{128, 2, 128}; - std::vector Y_dims{128}; std::vector axes{1, 2}; - bool keepdims = false; - TestReduceSum(X_dims, Y_dims, axes, keepdims); + TestReduceSum(X_dims, {128}, axes, false); + TestReduceSum(X_dims, {128, 1, 1}, axes, true); } TEST(CudaKernelTest, ReduceSum_MidTensorTrailingAxes) { std::vector X_dims{3072, 2, 512}; - std::vector Y_dims{3072}; std::vector axes{1, 2}; - bool keepdims = false; - TestReduceSum(X_dims, Y_dims, axes, keepdims); + TestReduceSum(X_dims, {3072}, axes, false); + TestReduceSum(X_dims, {3072, 1, 1}, axes, true); } TEST(CudaKernelTest, ReduceSum_LargeTensorTrailingAxes) { std::vector X_dims{30528, 4, 512}; - std::vector Y_dims{30528}; std::vector axes{1, 2}; - bool keepdims = false; - TestReduceSum(X_dims, Y_dims, axes, keepdims); + TestReduceSum(X_dims, {30528}, axes, false); + TestReduceSum(X_dims, {30528, 1, 1}, axes, true); +} + +TEST(CudaKernelTest, ReduceSum_OneDimsOptimization) { + std::vector X_dims{2, 3, 1, 4, 1, 5}; + std::vector axes{0, 2, 4}; + TestReduceSum(X_dims, {3, 4, 5}, axes, false); + TestReduceSum(X_dims, {1, 3, 1, 4, 1, 5}, axes, true); +} + +TEST(CudaKernelTest, ReduceSum_ReduceOnOneDims) { + std::vector X_dims{2, 1, 1}; + std::vector axes{1, 2}; + TestReduceSum(X_dims, {2}, axes, false); + TestReduceSum(X_dims, {2, 1, 1}, axes, true); +} + +TEST(CudaKernelTest, ReduceSum_AllOneDims) { + std::vector X_dims{1, 1}; + std::vector axes{}; + TestReduceSum(X_dims, {}, axes, false); + TestReduceSum(X_dims, {1, 1}, axes, true); } } // namespace test