diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc index d5829613fd..8d08e9047b 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc @@ -396,24 +396,24 @@ Status Min::ComputeInternal(OpKernelContext* context) const { //Greater op output tensor type is bool, so it cannot directly fit in the macros //for other elementwise ops -template -Status Greater::ComputeInternal(OpKernelContext* context) const { - typedef typename ToCudaType::MappedType CudaT; - const onnxruntime::Node& node = OpKernel::Node(); - const std::string& name = node.Name(); - - const Tensor* input0 = context->Input(0); - const Tensor* input1 = context->Input(1); - TensorShape output_shape; - ORT_RETURN_IF_ERROR(ComputeOutputShape(name, input0->Shape(), input1->Shape(), output_shape)); - size_t output_size = output_shape.Size(); - Tensor* output_tensor = context->Output(0, output_shape); - +template +Status CompareFunction::CompareMethod(OpKernelContext* context, void (*Impl_Compare)( + size_t output_rank_or_simple_broadcast, + const int64_t* lhs_padded_strides, + const CudaT* lhs_data, + const int64_t* rhs_padded_strides, + const CudaT* rhs_data, + const fast_divmod* fdm_output_strides, + const fast_divmod& fdm_H, + const fast_divmod& fdm_C, + CudaT* output_data, + size_t count)) const { BinaryElementwisePreparation prepare(this); - ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(input0, input1, output_tensor, &prepare)); - + Prepare(context, &prepare); + size_t output_size = prepare.output_tensor->Shape().Size(); IAllocatorUniquePtr output_buffer = GetScratchBuffer(output_size); - Impl_Greater( + ORT_RETURN_IF_ERROR(prepare.CopyToGpu()); + Impl_Compare( prepare.output_rank_or_simple_broadcast, prepare.lhs_padded_strides.GpuPtr(), reinterpret_cast(prepare.lhs_tensor->template Data()), @@ -423,48 +423,31 @@ Status Greater::ComputeInternal(OpKernelContext* context) const { prepare.fdm_H, prepare.fdm_C, reinterpret_cast(output_buffer.get()), - output_size); + prepare.output_tensor->Shape().Size()); Impl_Cast::MappedType>( reinterpret_cast(output_buffer.get()), - reinterpret_cast::MappedType*>(output_tensor->template MutableData()), + reinterpret_cast::MappedType*>(prepare.output_tensor->template MutableData()), output_size); + + return Status::OK(); +} + +//Greater op output tensor type is bool, so it cannot directly fit in the macros +//for other elementwise ops +template +Status Greater::ComputeInternal(OpKernelContext* context) const { + + this->CompareMethod(context, &Impl_Greater); + return Status::OK(); } template Status Equal::ComputeInternal(OpKernelContext* context) const { - typedef typename ToCudaType::MappedType CudaT; - const onnxruntime::Node& node = OpKernel::Node(); - const std::string& name = node.Name(); - const Tensor* input0 = context->Input(0); - const Tensor* input1 = context->Input(1); - TensorShape output_shape; - ORT_RETURN_IF_ERROR(ComputeOutputShape(name, input0->Shape(), input1->Shape(), output_shape)); - size_t output_size = output_shape.Size(); - Tensor* output_tensor = context->Output(0, output_shape); + this->CompareMethod(context, &Impl_Equal); - BinaryElementwisePreparation prepare(this); - ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(input0, input1, output_tensor, &prepare)); - - IAllocatorUniquePtr output_buffer = GetScratchBuffer(output_size); - Impl_Equal( - prepare.output_rank_or_simple_broadcast, - prepare.lhs_padded_strides.GpuPtr(), - reinterpret_cast(prepare.lhs_tensor->template Data()), - prepare.rhs_padded_strides.GpuPtr(), - reinterpret_cast(prepare.rhs_tensor->template Data()), - prepare.fdm_output_strides.GpuPtr(), - prepare.fdm_H, - prepare.fdm_C, - reinterpret_cast(output_buffer.get()), - output_size); - - Impl_Cast::MappedType>( - reinterpret_cast(output_buffer.get()), - reinterpret_cast::MappedType*>(output_tensor->template MutableData()), - output_size); return Status::OK(); } @@ -472,37 +455,9 @@ Status Equal::ComputeInternal(OpKernelContext* context) const { //for other elementwise ops template Status Less::ComputeInternal(OpKernelContext* context) const { - typedef typename ToCudaType::MappedType CudaT; - const onnxruntime::Node& node = OpKernel::Node(); - const std::string& name = node.Name(); - const Tensor* input0 = context->Input(0); - const Tensor* input1 = context->Input(1); - TensorShape output_shape; - ORT_RETURN_IF_ERROR(ComputeOutputShape(name, input0->Shape(), input1->Shape(), output_shape)); - size_t output_size = output_shape.Size(); - Tensor* output_tensor = context->Output(0, output_shape); + this->CompareMethod(context, &Impl_Less); - BinaryElementwisePreparation prepare(this); - ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(input0, input1, output_tensor, &prepare)); - - IAllocatorUniquePtr output_buffer = GetScratchBuffer(output_size); - Impl_Less( - prepare.output_rank_or_simple_broadcast, - prepare.lhs_padded_strides.GpuPtr(), - reinterpret_cast(prepare.lhs_tensor->template Data()), - prepare.rhs_padded_strides.GpuPtr(), - reinterpret_cast(prepare.rhs_tensor->template Data()), - prepare.fdm_output_strides.GpuPtr(), - prepare.fdm_H, - prepare.fdm_C, - reinterpret_cast(output_buffer.get()), - output_size); - - Impl_Cast::MappedType>( - reinterpret_cast(output_buffer.get()), - reinterpret_cast::MappedType*>(output_tensor->template MutableData()), - output_size); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h index 79036c1e63..e662f99ebc 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h @@ -196,22 +196,6 @@ class Sum final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; }; -template -class Greater final : public CudaKernel { - public: - Greater(const OpKernelInfo& info) : CudaKernel(info) {} - - Status ComputeInternal(OpKernelContext* context) const override; -}; - -template -class Equal final : public CudaKernel { - public: - Equal(const OpKernelInfo& info) : CudaKernel(info) {} - - Status ComputeInternal(OpKernelContext* context) const override; -}; - template class Max final : public CudaKernel { @@ -231,10 +215,44 @@ class Min final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; }; -template -class Less final : public CudaKernel { +template +class CompareFunction : public BinaryElementwise { public: - Less(const OpKernelInfo& info) : CudaKernel(info) {} + CompareFunction(const OpKernelInfo& info) : BinaryElementwise(info) {} + + Status CompareMethod(OpKernelContext* context, void (*Impl_Compare)( + size_t output_rank_or_simple_broadcast, + const int64_t* lhs_padded_strides, + const CudaT* lhs_data, + const int64_t* rhs_padded_strides, + const CudaT* rhs_data, + const fast_divmod* fdm_output_strides, + const fast_divmod& fdm_H, + const fast_divmod& fdm_C, + CudaT* output_data, + size_t count)) const; +}; + +template +class Greater final : public CompareFunction::MappedType> { + public: + Greater(const OpKernelInfo& info) : CompareFunction::MappedType>(info) {} + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +template +class Equal final : public CompareFunction::MappedType> { + public: + Equal(const OpKernelInfo& info) : CompareFunction::MappedType>(info) {} + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +template +class Less final : public CompareFunction::MappedType> { + public: + Less(const OpKernelInfo& info) : CompareFunction::MappedType>(info) {} Status ComputeInternal(OpKernelContext* context) const override; }; diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu index 18b7d3514f..04b80bab3f 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.cu @@ -78,6 +78,7 @@ BINARY_OPS() // O: bool SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Add) +SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Add, bool) SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Sub) SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Mul) SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Div) diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index d598ce2559..3d66b1205a 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -88,6 +88,38 @@ TEST(MathOpTest, Add_Broadcast_Axis) { test.Run(OpTester::ExpectResult::kExpectSuccess, ""); } +TEST(MathOpTest, Add_Broadcast_MultidirectionalAB) { + OpTester test("Add"); + + test.AddInput("A", {3, 1}, + {3.0f, + 2.0f, + 1.0f}); + test.AddInput("B", {3}, + {1.0f, 2.0f, 3.0f}); + test.AddOutput("C", {3, 3}, + {4.0f, 5.0f, 6.0f, + 3.0f, 4.0f, 5.0f, + 2.0f, 3.0f, 4.0f}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: got C with shape [3, 1] +} + +TEST(MathOpTest, Add_Broadcast_MultidirectionalBA) { + OpTester test("Add"); + + test.AddInput("A", {3}, + {1.0f, 2.0f, 3.0f}); + test.AddInput("B", {3, 1}, + {3.0f, + 2.0f, + 1.0f}); + test.AddOutput("C", {3, 3}, + {4.0f, 5.0f, 6.0f, + 3.0f, 4.0f, 5.0f, + 2.0f, 3.0f, 4.0f}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: got C with shape [3, 1] +} + TEST(MathOpTest, Add_Broadcast_0x0) { OpTester test("Add"); @@ -854,6 +886,37 @@ TEST(MathOpTest, Less_int64_Scalar1) { test.AddOutput("C", {4}, {false, true, false, true}); test.Run(); } +TEST(MathOpTest, Less_broadcastAB) { + OpTester test("Less", 9); + test.AddInput("A", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17}); + test.AddInput("B", {2}, {15, 7}); + test.AddOutput("C", {4, 2}, {true, false, true, false, true, false, false, false}); + test.Run(); +} + +TEST(MathOpTest, Less_broadcastBA) { + OpTester test("Less", 9); + test.AddInput("A", {2}, {15, 7}); + test.AddInput("B", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17}); + test.AddOutput("C", {4, 2}, {false, true, false, true, false, true, true, true}); + test.Run(); +} + +TEST(MathOpTest, Less_multidiretional_broadcastAB) { + OpTester test("Less", 9); + test.AddInput("A", {4, 1}, {10, 11, 12, 13}); + test.AddInput("B", {2}, {15, 7}); + test.AddOutput("C", {4, 2}, {true, false, true, false, true, false, true, false}); + test.Run(); +} + +TEST(MathOpTest, Less_multidiretional_broadcastBA) { + OpTester test("Less", 9); + test.AddInput("A", {2}, {15, 7}); + test.AddInput("B", {4, 1}, {10, 11, 12, 13}); + test.AddOutput("C", {4, 2}, {false, true, false, true, false, true, false, true}); + test.Run(); +} TEST(MathOpTest, Greater_7) { OpTester test("Greater"); @@ -891,6 +954,38 @@ TEST(MathOpTest, Greater_9_int64) { test.Run(); } +TEST(MathOpTest, Greater_broadcastAB) { + OpTester test("Greater", 9); + test.AddInput("A", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17}); + test.AddInput("B", {2}, {15, 7}); + test.AddOutput("C", {4, 2}, {false, true, false, true, false, true, true, true}); + test.Run(); +} + +TEST(MathOpTest, Greater_broadcastBA) { + OpTester test("Greater", 9); + test.AddInput("A", {2}, {15, 7}); + test.AddInput("B", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17}); + test.AddOutput("C", {4, 2}, {true, false, true, false, true, false, false, false}); + test.Run(); +} + +TEST(MathOpTest, Greater_multidiretional_broadcastAB) { + OpTester test("Greater", 9); + test.AddInput("A", {4, 1}, {10, 11, 12, 13}); + test.AddInput("B", {2}, {15, 7}); + test.AddOutput("C", {4, 2}, {false, true, false, true, false, true, false, true}); + test.Run(); +} + +TEST(MathOpTest, Greater_multidiretional_broadcastBA) { + OpTester test("Greater", 9); + test.AddInput("A", {2}, {15, 7}); + test.AddInput("B", {4, 1}, {10, 11, 12, 13}); + test.AddOutput("C", {4, 2}, {true, false, true, false, true, false, true, false}); + test.Run(); +} + TEST(MathOpTest, Equal_bool) { OpTester test("Equal"); std::vector dims{4}; @@ -943,6 +1038,46 @@ TEST(MathOpTest, Equal_float) { test.Run(); } +TEST(MathOpTest, Equal_broadcastAB) { + OpTester test("Equal"); + test.AddInput("A", {4, 2}, {1, 0, -1, -1, 1, 1, -1, 0}); + test.AddInput("B", {2}, {1, 1}); + test.AddOutput("C", {4, 2}, {true, false, false, false, true, true, false, false}); + test.Run(); +} + +TEST(MathOpTest, Equal_broadcastBA) { + OpTester test("Equal"); + test.AddInput("A", {2}, {1, 1}); + test.AddInput("B", {4, 2}, {1, 0, -1, -1, 1, 1, -1, 0}); + test.AddOutput("C", {4, 2}, {true, false, false, false, true, true, false, false}); + test.Run(); +} + +TEST(MathOpTest, Equal_multidiretional_broadcastAB) { + OpTester test("Equal"); + test.AddInput("A", {4, 1}, {1, 0, -1, -1}); + test.AddInput("B", {2}, {1, 1}); + test.AddOutput("C", {4, 2}, {true, true, false, false, false, false, false, false}); + test.Run(); +} + +TEST(MathOpTest, Equal_multidiretional_broadcastBA) { + OpTester test("Equal"); + test.AddInput("A", {2}, {1, 1}); + test.AddInput("B", {4, 1}, {1, 0, -1, -1}); + test.AddOutput("C", {4, 2}, {true, true, false, false, false, false, false, false}); + test.Run(); +} + +TEST(MathOpTest, Equal_multidiretional_broadcastAB_bool) { + OpTester test("Equal"); + test.AddInput("A", {4, 1}, {true, false, false, false}); + test.AddInput("B", {2}, {true, true}); + test.AddOutput("C", {4, 2}, {true, true, false, false, false, false, false, false}); + test.Run(); +} + TEST(MathOpTest, Mean_6) { OpTester test("Mean", 6); std::vector dims{3, 3};