mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
CUDA Equal, Greater, and Less don't support multi-directional broadcast
Fix issue #1591. Root cause: the CUDA Equal, Greater, and Less kernels do not support multi-directional broadcast. Fix: add code to support multi-directional broadcast, and add tests to cover more cases.
This commit is contained in:
parent
1a3ded6a7b
commit
e288b871ea
4 changed files with 204 additions and 95 deletions
|
|
@ -396,24 +396,24 @@ Status Min<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
//Greater op output tensor type is bool, so it cannot directly fit in the macros
|
||||
//for other elementwise ops
|
||||
template <typename T>
|
||||
Status Greater<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
const onnxruntime::Node& node = OpKernel::Node();
|
||||
const std::string& name = node.Name();
|
||||
|
||||
const Tensor* input0 = context->Input<Tensor>(0);
|
||||
const Tensor* input1 = context->Input<Tensor>(1);
|
||||
TensorShape output_shape;
|
||||
ORT_RETURN_IF_ERROR(ComputeOutputShape(name, input0->Shape(), input1->Shape(), output_shape));
|
||||
size_t output_size = output_shape.Size();
|
||||
Tensor* output_tensor = context->Output(0, output_shape);
|
||||
|
||||
template <typename T, typename CudaT>
|
||||
Status CompareFunction<T, CudaT>::CompareMethod(OpKernelContext* context, void (*Impl_Compare)(
|
||||
size_t output_rank_or_simple_broadcast,
|
||||
const int64_t* lhs_padded_strides,
|
||||
const CudaT* lhs_data,
|
||||
const int64_t* rhs_padded_strides,
|
||||
const CudaT* rhs_data,
|
||||
const fast_divmod* fdm_output_strides,
|
||||
const fast_divmod& fdm_H,
|
||||
const fast_divmod& fdm_C,
|
||||
CudaT* output_data,
|
||||
size_t count)) const {
|
||||
BinaryElementwisePreparation prepare(this);
|
||||
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(input0, input1, output_tensor, &prepare));
|
||||
|
||||
Prepare(context, &prepare);
|
||||
size_t output_size = prepare.output_tensor->Shape().Size();
|
||||
IAllocatorUniquePtr<T> output_buffer = GetScratchBuffer<T>(output_size);
|
||||
Impl_Greater<CudaT>(
|
||||
ORT_RETURN_IF_ERROR(prepare.CopyToGpu());
|
||||
Impl_Compare(
|
||||
prepare.output_rank_or_simple_broadcast,
|
||||
prepare.lhs_padded_strides.GpuPtr(),
|
||||
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()),
|
||||
|
|
@ -423,48 +423,31 @@ Status Greater<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
prepare.fdm_H,
|
||||
prepare.fdm_C,
|
||||
reinterpret_cast<CudaT*>(output_buffer.get()),
|
||||
output_size);
|
||||
prepare.output_tensor->Shape().Size());
|
||||
|
||||
Impl_Cast<CudaT, ToCudaType<bool>::MappedType>(
|
||||
reinterpret_cast<CudaT*>(output_buffer.get()),
|
||||
reinterpret_cast<ToCudaType<bool>::MappedType*>(output_tensor->template MutableData<bool>()),
|
||||
reinterpret_cast<ToCudaType<bool>::MappedType*>(prepare.output_tensor->template MutableData<bool>()),
|
||||
output_size);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//Greater op output tensor type is bool, so it cannot directly fit in the macros
|
||||
//for other elementwise ops
|
||||
template <typename T>
|
||||
Status Greater<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
|
||||
this->CompareMethod(context, &Impl_Greater);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Equal<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
const onnxruntime::Node& node = OpKernel::Node();
|
||||
const std::string& name = node.Name();
|
||||
|
||||
const Tensor* input0 = context->Input<Tensor>(0);
|
||||
const Tensor* input1 = context->Input<Tensor>(1);
|
||||
TensorShape output_shape;
|
||||
ORT_RETURN_IF_ERROR(ComputeOutputShape(name, input0->Shape(), input1->Shape(), output_shape));
|
||||
size_t output_size = output_shape.Size();
|
||||
Tensor* output_tensor = context->Output(0, output_shape);
|
||||
this->CompareMethod(context, &Impl_Equal);
|
||||
|
||||
BinaryElementwisePreparation prepare(this);
|
||||
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(input0, input1, output_tensor, &prepare));
|
||||
|
||||
IAllocatorUniquePtr<T> output_buffer = GetScratchBuffer<T>(output_size);
|
||||
Impl_Equal<CudaT>(
|
||||
prepare.output_rank_or_simple_broadcast,
|
||||
prepare.lhs_padded_strides.GpuPtr(),
|
||||
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()),
|
||||
prepare.rhs_padded_strides.GpuPtr(),
|
||||
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()),
|
||||
prepare.fdm_output_strides.GpuPtr(),
|
||||
prepare.fdm_H,
|
||||
prepare.fdm_C,
|
||||
reinterpret_cast<CudaT*>(output_buffer.get()),
|
||||
output_size);
|
||||
|
||||
Impl_Cast<CudaT, ToCudaType<bool>::MappedType>(
|
||||
reinterpret_cast<CudaT*>(output_buffer.get()),
|
||||
reinterpret_cast<ToCudaType<bool>::MappedType*>(output_tensor->template MutableData<bool>()),
|
||||
output_size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -472,37 +455,9 @@ Status Equal<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
//for other elementwise ops
|
||||
template <typename T>
|
||||
Status Less<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
const onnxruntime::Node& node = OpKernel::Node();
|
||||
const std::string& name = node.Name();
|
||||
|
||||
const Tensor* input0 = context->Input<Tensor>(0);
|
||||
const Tensor* input1 = context->Input<Tensor>(1);
|
||||
TensorShape output_shape;
|
||||
ORT_RETURN_IF_ERROR(ComputeOutputShape(name, input0->Shape(), input1->Shape(), output_shape));
|
||||
size_t output_size = output_shape.Size();
|
||||
Tensor* output_tensor = context->Output(0, output_shape);
|
||||
this->CompareMethod(context, &Impl_Less);
|
||||
|
||||
BinaryElementwisePreparation prepare(this);
|
||||
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(input0, input1, output_tensor, &prepare));
|
||||
|
||||
IAllocatorUniquePtr<T> output_buffer = GetScratchBuffer<T>(output_size);
|
||||
Impl_Less<CudaT>(
|
||||
prepare.output_rank_or_simple_broadcast,
|
||||
prepare.lhs_padded_strides.GpuPtr(),
|
||||
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()),
|
||||
prepare.rhs_padded_strides.GpuPtr(),
|
||||
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()),
|
||||
prepare.fdm_output_strides.GpuPtr(),
|
||||
prepare.fdm_H,
|
||||
prepare.fdm_C,
|
||||
reinterpret_cast<CudaT*>(output_buffer.get()),
|
||||
output_size);
|
||||
|
||||
Impl_Cast<CudaT, ToCudaType<bool>::MappedType>(
|
||||
reinterpret_cast<CudaT*>(output_buffer.get()),
|
||||
reinterpret_cast<ToCudaType<bool>::MappedType*>(output_tensor->template MutableData<bool>()),
|
||||
output_size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -196,22 +196,6 @@ class Sum final : public CudaKernel {
|
|||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Greater final : public CudaKernel {
|
||||
public:
|
||||
Greater(const OpKernelInfo& info) : CudaKernel(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Equal final : public CudaKernel {
|
||||
public:
|
||||
Equal(const OpKernelInfo& info) : CudaKernel(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
|
||||
template <typename T>
|
||||
class Max final : public CudaKernel {
|
||||
|
|
@ -231,10 +215,44 @@ class Min final : public CudaKernel {
|
|||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Less final : public CudaKernel {
|
||||
template <typename T, typename CudaT>
|
||||
class CompareFunction : public BinaryElementwise<ShouldBroadcast> {
|
||||
public:
|
||||
Less(const OpKernelInfo& info) : CudaKernel(info) {}
|
||||
CompareFunction(const OpKernelInfo& info) : BinaryElementwise(info) {}
|
||||
|
||||
Status CompareMethod(OpKernelContext* context, void (*Impl_Compare)(
|
||||
size_t output_rank_or_simple_broadcast,
|
||||
const int64_t* lhs_padded_strides,
|
||||
const CudaT* lhs_data,
|
||||
const int64_t* rhs_padded_strides,
|
||||
const CudaT* rhs_data,
|
||||
const fast_divmod* fdm_output_strides,
|
||||
const fast_divmod& fdm_H,
|
||||
const fast_divmod& fdm_C,
|
||||
CudaT* output_data,
|
||||
size_t count)) const;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Greater final : public CompareFunction<T, typename ToCudaType<T>::MappedType> {
|
||||
public:
|
||||
Greater(const OpKernelInfo& info) : CompareFunction<T, typename ToCudaType<T>::MappedType>(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Equal final : public CompareFunction<T, typename ToCudaType<T>::MappedType> {
|
||||
public:
|
||||
Equal(const OpKernelInfo& info) : CompareFunction<T, typename ToCudaType<T>::MappedType>(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Less final : public CompareFunction<T, typename ToCudaType<T>::MappedType> {
|
||||
public:
|
||||
Less(const OpKernelInfo& info) : CompareFunction<T, typename ToCudaType<T>::MappedType>(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -78,6 +78,7 @@ BINARY_OPS()
|
|||
// O: bool
|
||||
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Add)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Add, bool)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Sub)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Mul)
|
||||
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Div)
|
||||
|
|
|
|||
|
|
@ -88,6 +88,38 @@ TEST(MathOpTest, Add_Broadcast_Axis) {
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess, "");
|
||||
}
|
||||
|
||||
// Add with A of shape [3, 1] and B of shape [3]; the expected output C has
// shape [3, 3], so both inputs are broadcast.
TEST(MathOpTest, Add_Broadcast_MultidirectionalAB) {
  const std::vector<int64_t> a_dims{3, 1};
  const std::vector<int64_t> b_dims{3};
  const std::vector<int64_t> c_dims{3, 3};

  OpTester test("Add");
  test.AddInput<float>("A", a_dims, {3.0f, 2.0f, 1.0f});
  test.AddInput<float>("B", b_dims, {1.0f, 2.0f, 3.0f});
  test.AddOutput<float>("C", c_dims,
                        {4.0f, 5.0f, 6.0f,
                         3.0f, 4.0f, 5.0f,
                         2.0f, 3.0f, 4.0f});
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  //TensorRT: got C with shape [3, 1]
}
|
||||
|
||||
// Add with the operands swapped: A has shape [3] and B has shape [3, 1].
// The expected output C again has shape [3, 3].
TEST(MathOpTest, Add_Broadcast_MultidirectionalBA) {
  const std::vector<int64_t> a_dims{3};
  const std::vector<int64_t> b_dims{3, 1};
  const std::vector<int64_t> c_dims{3, 3};

  OpTester test("Add");
  test.AddInput<float>("A", a_dims, {1.0f, 2.0f, 3.0f});
  test.AddInput<float>("B", b_dims, {3.0f, 2.0f, 1.0f});
  test.AddOutput<float>("C", c_dims,
                        {4.0f, 5.0f, 6.0f,
                         3.0f, 4.0f, 5.0f,
                         2.0f, 3.0f, 4.0f});
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  //TensorRT: got C with shape [3, 1]
}
|
||||
|
||||
TEST(MathOpTest, Add_Broadcast_0x0) {
|
||||
OpTester test("Add");
|
||||
|
||||
|
|
@ -854,6 +886,37 @@ TEST(MathOpTest, Less_int64_Scalar1) {
|
|||
test.AddOutput<bool>("C", {4}, {false, true, false, true});
|
||||
test.Run();
|
||||
}
|
||||
// Less (opset 9) with A of shape [4, 2] and B of shape [2]: each row of A is
// compared element-wise against B.
TEST(MathOpTest, Less_broadcastAB) {
  const std::vector<int64_t> a_dims{4, 2};
  const std::vector<int64_t> b_dims{2};
  const std::vector<int64_t> c_dims{4, 2};

  OpTester test("Less", 9);
  test.AddInput<int32_t>("A", a_dims, {10, 11, 12, 13, 14, 15, 16, 17});
  test.AddInput<int32_t>("B", b_dims, {15, 7});
  test.AddOutput<bool>("C", c_dims,
                       {true, false,
                        true, false,
                        true, false,
                        false, false});
  test.Run();
}
|
||||
|
||||
// Less (opset 9) with the smaller operand first: A has shape [2] and B has
// shape [4, 2], so A is compared against each row of B.
TEST(MathOpTest, Less_broadcastBA) {
  const std::vector<int64_t> a_dims{2};
  const std::vector<int64_t> b_dims{4, 2};
  const std::vector<int64_t> c_dims{4, 2};

  OpTester test("Less", 9);
  test.AddInput<int32_t>("A", a_dims, {15, 7});
  test.AddInput<int32_t>("B", b_dims, {10, 11, 12, 13, 14, 15, 16, 17});
  test.AddOutput<bool>("C", c_dims,
                       {false, true,
                        false, true,
                        false, true,
                        true, true});
  test.Run();
}
|
||||
|
||||
// Less (opset 9) where both inputs need broadcasting: A is [4, 1], B is [2],
// and the expected output is [4, 2].
TEST(MathOpTest, Less_multidiretional_broadcastAB) {
  const std::vector<int64_t> a_dims{4, 1};
  const std::vector<int64_t> b_dims{2};
  const std::vector<int64_t> c_dims{4, 2};

  OpTester test("Less", 9);
  test.AddInput<int32_t>("A", a_dims, {10, 11, 12, 13});
  test.AddInput<int32_t>("B", b_dims, {15, 7});
  test.AddOutput<bool>("C", c_dims,
                       {true, false,
                        true, false,
                        true, false,
                        true, false});
  test.Run();
}
|
||||
|
||||
// Less (opset 9) where both inputs need broadcasting, operands swapped:
// A is [2], B is [4, 1], and the expected output is [4, 2].
TEST(MathOpTest, Less_multidiretional_broadcastBA) {
  const std::vector<int64_t> a_dims{2};
  const std::vector<int64_t> b_dims{4, 1};
  const std::vector<int64_t> c_dims{4, 2};

  OpTester test("Less", 9);
  test.AddInput<int32_t>("A", a_dims, {15, 7});
  test.AddInput<int32_t>("B", b_dims, {10, 11, 12, 13});
  test.AddOutput<bool>("C", c_dims,
                       {false, true,
                        false, true,
                        false, true,
                        false, true});
  test.Run();
}
|
||||
|
||||
TEST(MathOpTest, Greater_7) {
|
||||
OpTester test("Greater");
|
||||
|
|
@ -891,6 +954,38 @@ TEST(MathOpTest, Greater_9_int64) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
// Greater (opset 9) with A of shape [4, 2] and B of shape [2]: each row of A
// is compared element-wise against B.
TEST(MathOpTest, Greater_broadcastAB) {
  const std::vector<int64_t> a_dims{4, 2};
  const std::vector<int64_t> b_dims{2};
  const std::vector<int64_t> c_dims{4, 2};

  OpTester test("Greater", 9);
  test.AddInput<int32_t>("A", a_dims, {10, 11, 12, 13, 14, 15, 16, 17});
  test.AddInput<int32_t>("B", b_dims, {15, 7});
  test.AddOutput<bool>("C", c_dims,
                       {false, true,
                        false, true,
                        false, true,
                        true, true});
  test.Run();
}
|
||||
|
||||
// Greater (opset 9) with the smaller operand first: A has shape [2] and B has
// shape [4, 2], so A is compared against each row of B.
TEST(MathOpTest, Greater_broadcastBA) {
  const std::vector<int64_t> a_dims{2};
  const std::vector<int64_t> b_dims{4, 2};
  const std::vector<int64_t> c_dims{4, 2};

  OpTester test("Greater", 9);
  test.AddInput<int32_t>("A", a_dims, {15, 7});
  test.AddInput<int32_t>("B", b_dims, {10, 11, 12, 13, 14, 15, 16, 17});
  test.AddOutput<bool>("C", c_dims,
                       {true, false,
                        true, false,
                        true, false,
                        false, false});
  test.Run();
}
|
||||
|
||||
// Greater (opset 9) where both inputs need broadcasting: A is [4, 1], B is
// [2], and the expected output is [4, 2].
TEST(MathOpTest, Greater_multidiretional_broadcastAB) {
  const std::vector<int64_t> a_dims{4, 1};
  const std::vector<int64_t> b_dims{2};
  const std::vector<int64_t> c_dims{4, 2};

  OpTester test("Greater", 9);
  test.AddInput<int32_t>("A", a_dims, {10, 11, 12, 13});
  test.AddInput<int32_t>("B", b_dims, {15, 7});
  test.AddOutput<bool>("C", c_dims,
                       {false, true,
                        false, true,
                        false, true,
                        false, true});
  test.Run();
}
|
||||
|
||||
// Greater (opset 9) where both inputs need broadcasting, operands swapped:
// A is [2], B is [4, 1], and the expected output is [4, 2].
TEST(MathOpTest, Greater_multidiretional_broadcastBA) {
  const std::vector<int64_t> a_dims{2};
  const std::vector<int64_t> b_dims{4, 1};
  const std::vector<int64_t> c_dims{4, 2};

  OpTester test("Greater", 9);
  test.AddInput<int32_t>("A", a_dims, {15, 7});
  test.AddInput<int32_t>("B", b_dims, {10, 11, 12, 13});
  test.AddOutput<bool>("C", c_dims,
                       {true, false,
                        true, false,
                        true, false,
                        true, false});
  test.Run();
}
|
||||
|
||||
TEST(MathOpTest, Equal_bool) {
|
||||
OpTester test("Equal");
|
||||
std::vector<int64_t> dims{4};
|
||||
|
|
@ -943,6 +1038,46 @@ TEST(MathOpTest, Equal_float) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
// Equal with A of shape [4, 2] and B of shape [2]: each row of A is compared
// element-wise against B.
TEST(MathOpTest, Equal_broadcastAB) {
  const std::vector<int64_t> a_dims{4, 2};
  const std::vector<int64_t> b_dims{2};
  const std::vector<int64_t> c_dims{4, 2};

  OpTester test("Equal");
  test.AddInput<int32_t>("A", a_dims, {1, 0, -1, -1, 1, 1, -1, 0});
  test.AddInput<int32_t>("B", b_dims, {1, 1});
  test.AddOutput<bool>("C", c_dims,
                       {true, false,
                        false, false,
                        true, true,
                        false, false});
  test.Run();
}
|
||||
|
||||
// Equal with the smaller operand first: A has shape [2] and B has shape
// [4, 2], so A is compared against each row of B.
TEST(MathOpTest, Equal_broadcastBA) {
  const std::vector<int64_t> a_dims{2};
  const std::vector<int64_t> b_dims{4, 2};
  const std::vector<int64_t> c_dims{4, 2};

  OpTester test("Equal");
  test.AddInput<int32_t>("A", a_dims, {1, 1});
  test.AddInput<int32_t>("B", b_dims, {1, 0, -1, -1, 1, 1, -1, 0});
  test.AddOutput<bool>("C", c_dims,
                       {true, false,
                        false, false,
                        true, true,
                        false, false});
  test.Run();
}
|
||||
|
||||
// Equal where both inputs need broadcasting: A is [4, 1], B is [2], and the
// expected output is [4, 2].
TEST(MathOpTest, Equal_multidiretional_broadcastAB) {
  const std::vector<int64_t> a_dims{4, 1};
  const std::vector<int64_t> b_dims{2};
  const std::vector<int64_t> c_dims{4, 2};

  OpTester test("Equal");
  test.AddInput<int32_t>("A", a_dims, {1, 0, -1, -1});
  test.AddInput<int32_t>("B", b_dims, {1, 1});
  test.AddOutput<bool>("C", c_dims,
                       {true, true,
                        false, false,
                        false, false,
                        false, false});
  test.Run();
}
|
||||
|
||||
// Equal where both inputs need broadcasting, operands swapped: A is [2],
// B is [4, 1], and the expected output is [4, 2].
TEST(MathOpTest, Equal_multidiretional_broadcastBA) {
  const std::vector<int64_t> a_dims{2};
  const std::vector<int64_t> b_dims{4, 1};
  const std::vector<int64_t> c_dims{4, 2};

  OpTester test("Equal");
  test.AddInput<int32_t>("A", a_dims, {1, 1});
  test.AddInput<int32_t>("B", b_dims, {1, 0, -1, -1});
  test.AddOutput<bool>("C", c_dims,
                       {true, true,
                        false, false,
                        false, false,
                        false, false});
  test.Run();
}
|
||||
|
||||
// Equal on bool inputs where both need broadcasting: A is [4, 1], B is [2],
// and the expected output is [4, 2].
TEST(MathOpTest, Equal_multidiretional_broadcastAB_bool) {
  const std::vector<int64_t> a_dims{4, 1};
  const std::vector<int64_t> b_dims{2};
  const std::vector<int64_t> c_dims{4, 2};

  OpTester test("Equal");
  test.AddInput<bool>("A", a_dims, {true, false, false, false});
  test.AddInput<bool>("B", b_dims, {true, true});
  test.AddOutput<bool>("C", c_dims,
                       {true, true,
                        false, false,
                        false, false,
                        false, false});
  test.Run();
}
|
||||
|
||||
TEST(MathOpTest, Mean_6) {
|
||||
OpTester test("Mean", 6);
|
||||
std::vector<int64_t> dims{3, 3};
|
||||
|
|
|
|||
Loading…
Reference in a new issue