CUDA Equal Greater Less can't support multi-directional broadcast

Fix issue #1591

Root Cause:
CUDA Equal Greater Less do not support multi-directional broadcast

Fix:
Add code to support the multi-directional broadcast
Also add tests to cover more cases.
This commit is contained in:
Hector Li 2019-09-23 22:21:52 -07:00 committed by GitHub
parent 1a3ded6a7b
commit e288b871ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 204 additions and 95 deletions

View file

@ -396,24 +396,24 @@ Status Min<T>::ComputeInternal(OpKernelContext* context) const {
//Greater op output tensor type is bool, so it cannot directly fit in the macros
//for other elementwise ops
template <typename T>
Status Greater<T>::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType<T>::MappedType CudaT;
const onnxruntime::Node& node = OpKernel::Node();
const std::string& name = node.Name();
const Tensor* input0 = context->Input<Tensor>(0);
const Tensor* input1 = context->Input<Tensor>(1);
TensorShape output_shape;
ORT_RETURN_IF_ERROR(ComputeOutputShape(name, input0->Shape(), input1->Shape(), output_shape));
size_t output_size = output_shape.Size();
Tensor* output_tensor = context->Output(0, output_shape);
template <typename T, typename CudaT>
Status CompareFunction<T, CudaT>::CompareMethod(OpKernelContext* context, void (*Impl_Compare)(
size_t output_rank_or_simple_broadcast,
const int64_t* lhs_padded_strides,
const CudaT* lhs_data,
const int64_t* rhs_padded_strides,
const CudaT* rhs_data,
const fast_divmod* fdm_output_strides,
const fast_divmod& fdm_H,
const fast_divmod& fdm_C,
CudaT* output_data,
size_t count)) const {
BinaryElementwisePreparation prepare(this);
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(input0, input1, output_tensor, &prepare));
Prepare(context, &prepare);
size_t output_size = prepare.output_tensor->Shape().Size();
IAllocatorUniquePtr<T> output_buffer = GetScratchBuffer<T>(output_size);
Impl_Greater<CudaT>(
ORT_RETURN_IF_ERROR(prepare.CopyToGpu());
Impl_Compare(
prepare.output_rank_or_simple_broadcast,
prepare.lhs_padded_strides.GpuPtr(),
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()),
@ -423,48 +423,31 @@ Status Greater<T>::ComputeInternal(OpKernelContext* context) const {
prepare.fdm_H,
prepare.fdm_C,
reinterpret_cast<CudaT*>(output_buffer.get()),
output_size);
prepare.output_tensor->Shape().Size());
Impl_Cast<CudaT, ToCudaType<bool>::MappedType>(
reinterpret_cast<CudaT*>(output_buffer.get()),
reinterpret_cast<ToCudaType<bool>::MappedType*>(output_tensor->template MutableData<bool>()),
reinterpret_cast<ToCudaType<bool>::MappedType*>(prepare.output_tensor->template MutableData<bool>()),
output_size);
return Status::OK();
}
//Greater op output tensor type is bool, so it cannot directly fit in the macros
//for other elementwise ops
template <typename T>
Status Greater<T>::ComputeInternal(OpKernelContext* context) const {
this->CompareMethod(context, &Impl_Greater);
return Status::OK();
}
template <typename T>
Status Equal<T>::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType<T>::MappedType CudaT;
const onnxruntime::Node& node = OpKernel::Node();
const std::string& name = node.Name();
const Tensor* input0 = context->Input<Tensor>(0);
const Tensor* input1 = context->Input<Tensor>(1);
TensorShape output_shape;
ORT_RETURN_IF_ERROR(ComputeOutputShape(name, input0->Shape(), input1->Shape(), output_shape));
size_t output_size = output_shape.Size();
Tensor* output_tensor = context->Output(0, output_shape);
this->CompareMethod(context, &Impl_Equal);
BinaryElementwisePreparation prepare(this);
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(input0, input1, output_tensor, &prepare));
IAllocatorUniquePtr<T> output_buffer = GetScratchBuffer<T>(output_size);
Impl_Equal<CudaT>(
prepare.output_rank_or_simple_broadcast,
prepare.lhs_padded_strides.GpuPtr(),
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()),
prepare.rhs_padded_strides.GpuPtr(),
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()),
prepare.fdm_output_strides.GpuPtr(),
prepare.fdm_H,
prepare.fdm_C,
reinterpret_cast<CudaT*>(output_buffer.get()),
output_size);
Impl_Cast<CudaT, ToCudaType<bool>::MappedType>(
reinterpret_cast<CudaT*>(output_buffer.get()),
reinterpret_cast<ToCudaType<bool>::MappedType*>(output_tensor->template MutableData<bool>()),
output_size);
return Status::OK();
}
@ -472,37 +455,9 @@ Status Equal<T>::ComputeInternal(OpKernelContext* context) const {
//for other elementwise ops
template <typename T>
Status Less<T>::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType<T>::MappedType CudaT;
const onnxruntime::Node& node = OpKernel::Node();
const std::string& name = node.Name();
const Tensor* input0 = context->Input<Tensor>(0);
const Tensor* input1 = context->Input<Tensor>(1);
TensorShape output_shape;
ORT_RETURN_IF_ERROR(ComputeOutputShape(name, input0->Shape(), input1->Shape(), output_shape));
size_t output_size = output_shape.Size();
Tensor* output_tensor = context->Output(0, output_shape);
this->CompareMethod(context, &Impl_Less);
BinaryElementwisePreparation prepare(this);
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(input0, input1, output_tensor, &prepare));
IAllocatorUniquePtr<T> output_buffer = GetScratchBuffer<T>(output_size);
Impl_Less<CudaT>(
prepare.output_rank_or_simple_broadcast,
prepare.lhs_padded_strides.GpuPtr(),
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()),
prepare.rhs_padded_strides.GpuPtr(),
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()),
prepare.fdm_output_strides.GpuPtr(),
prepare.fdm_H,
prepare.fdm_C,
reinterpret_cast<CudaT*>(output_buffer.get()),
output_size);
Impl_Cast<CudaT, ToCudaType<bool>::MappedType>(
reinterpret_cast<CudaT*>(output_buffer.get()),
reinterpret_cast<ToCudaType<bool>::MappedType*>(output_tensor->template MutableData<bool>()),
output_size);
return Status::OK();
}

View file

@ -196,22 +196,6 @@ class Sum final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Greater final : public CudaKernel {
public:
Greater(const OpKernelInfo& info) : CudaKernel(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Equal final : public CudaKernel {
public:
Equal(const OpKernelInfo& info) : CudaKernel(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Max final : public CudaKernel {
@ -231,10 +215,44 @@ class Min final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Less final : public CudaKernel {
template <typename T, typename CudaT>
class CompareFunction : public BinaryElementwise<ShouldBroadcast> {
public:
Less(const OpKernelInfo& info) : CudaKernel(info) {}
CompareFunction(const OpKernelInfo& info) : BinaryElementwise(info) {}
Status CompareMethod(OpKernelContext* context, void (*Impl_Compare)(
size_t output_rank_or_simple_broadcast,
const int64_t* lhs_padded_strides,
const CudaT* lhs_data,
const int64_t* rhs_padded_strides,
const CudaT* rhs_data,
const fast_divmod* fdm_output_strides,
const fast_divmod& fdm_H,
const fast_divmod& fdm_C,
CudaT* output_data,
size_t count)) const;
};
template <typename T>
class Greater final : public CompareFunction<T, typename ToCudaType<T>::MappedType> {
public:
Greater(const OpKernelInfo& info) : CompareFunction<T, typename ToCudaType<T>::MappedType>(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Equal final : public CompareFunction<T, typename ToCudaType<T>::MappedType> {
public:
Equal(const OpKernelInfo& info) : CompareFunction<T, typename ToCudaType<T>::MappedType>(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Less final : public CompareFunction<T, typename ToCudaType<T>::MappedType> {
public:
Less(const OpKernelInfo& info) : CompareFunction<T, typename ToCudaType<T>::MappedType>(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};

View file

@ -78,6 +78,7 @@ BINARY_OPS()
// O: bool
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Add)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Add, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Sub)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Mul)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Div)

View file

@ -88,6 +88,38 @@ TEST(MathOpTest, Add_Broadcast_Axis) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "");
}
TEST(MathOpTest, Add_Broadcast_MultidirectionalAB) {
OpTester test("Add");
test.AddInput<float>("A", {3, 1},
{3.0f,
2.0f,
1.0f});
test.AddInput<float>("B", {3},
{1.0f, 2.0f, 3.0f});
test.AddOutput<float>("C", {3, 3},
{4.0f, 5.0f, 6.0f,
3.0f, 4.0f, 5.0f,
2.0f, 3.0f, 4.0f});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: got C with shape [3, 1]
}
TEST(MathOpTest, Add_Broadcast_MultidirectionalBA) {
OpTester test("Add");
test.AddInput<float>("A", {3},
{1.0f, 2.0f, 3.0f});
test.AddInput<float>("B", {3, 1},
{3.0f,
2.0f,
1.0f});
test.AddOutput<float>("C", {3, 3},
{4.0f, 5.0f, 6.0f,
3.0f, 4.0f, 5.0f,
2.0f, 3.0f, 4.0f});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: got C with shape [3, 1]
}
TEST(MathOpTest, Add_Broadcast_0x0) {
OpTester test("Add");
@ -854,6 +886,37 @@ TEST(MathOpTest, Less_int64_Scalar1) {
test.AddOutput<bool>("C", {4}, {false, true, false, true});
test.Run();
}
TEST(MathOpTest, Less_broadcastAB) {
OpTester test("Less", 9);
test.AddInput<int32_t>("A", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17});
test.AddInput<int32_t>("B", {2}, {15, 7});
test.AddOutput<bool>("C", {4, 2}, {true, false, true, false, true, false, false, false});
test.Run();
}
TEST(MathOpTest, Less_broadcastBA) {
OpTester test("Less", 9);
test.AddInput<int32_t>("A", {2}, {15, 7});
test.AddInput<int32_t>("B", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17});
test.AddOutput<bool>("C", {4, 2}, {false, true, false, true, false, true, true, true});
test.Run();
}
TEST(MathOpTest, Less_multidiretional_broadcastAB) {
OpTester test("Less", 9);
test.AddInput<int32_t>("A", {4, 1}, {10, 11, 12, 13});
test.AddInput<int32_t>("B", {2}, {15, 7});
test.AddOutput<bool>("C", {4, 2}, {true, false, true, false, true, false, true, false});
test.Run();
}
TEST(MathOpTest, Less_multidiretional_broadcastBA) {
OpTester test("Less", 9);
test.AddInput<int32_t>("A", {2}, {15, 7});
test.AddInput<int32_t>("B", {4, 1}, {10, 11, 12, 13});
test.AddOutput<bool>("C", {4, 2}, {false, true, false, true, false, true, false, true});
test.Run();
}
TEST(MathOpTest, Greater_7) {
OpTester test("Greater");
@ -891,6 +954,38 @@ TEST(MathOpTest, Greater_9_int64) {
test.Run();
}
TEST(MathOpTest, Greater_broadcastAB) {
OpTester test("Greater", 9);
test.AddInput<int32_t>("A", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17});
test.AddInput<int32_t>("B", {2}, {15, 7});
test.AddOutput<bool>("C", {4, 2}, {false, true, false, true, false, true, true, true});
test.Run();
}
TEST(MathOpTest, Greater_broadcastBA) {
OpTester test("Greater", 9);
test.AddInput<int32_t>("A", {2}, {15, 7});
test.AddInput<int32_t>("B", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17});
test.AddOutput<bool>("C", {4, 2}, {true, false, true, false, true, false, false, false});
test.Run();
}
TEST(MathOpTest, Greater_multidiretional_broadcastAB) {
OpTester test("Greater", 9);
test.AddInput<int32_t>("A", {4, 1}, {10, 11, 12, 13});
test.AddInput<int32_t>("B", {2}, {15, 7});
test.AddOutput<bool>("C", {4, 2}, {false, true, false, true, false, true, false, true});
test.Run();
}
TEST(MathOpTest, Greater_multidiretional_broadcastBA) {
OpTester test("Greater", 9);
test.AddInput<int32_t>("A", {2}, {15, 7});
test.AddInput<int32_t>("B", {4, 1}, {10, 11, 12, 13});
test.AddOutput<bool>("C", {4, 2}, {true, false, true, false, true, false, true, false});
test.Run();
}
TEST(MathOpTest, Equal_bool) {
OpTester test("Equal");
std::vector<int64_t> dims{4};
@ -943,6 +1038,46 @@ TEST(MathOpTest, Equal_float) {
test.Run();
}
TEST(MathOpTest, Equal_broadcastAB) {
OpTester test("Equal");
test.AddInput<int32_t>("A", {4, 2}, {1, 0, -1, -1, 1, 1, -1, 0});
test.AddInput<int32_t>("B", {2}, {1, 1});
test.AddOutput<bool>("C", {4, 2}, {true, false, false, false, true, true, false, false});
test.Run();
}
TEST(MathOpTest, Equal_broadcastBA) {
OpTester test("Equal");
test.AddInput<int32_t>("A", {2}, {1, 1});
test.AddInput<int32_t>("B", {4, 2}, {1, 0, -1, -1, 1, 1, -1, 0});
test.AddOutput<bool>("C", {4, 2}, {true, false, false, false, true, true, false, false});
test.Run();
}
TEST(MathOpTest, Equal_multidiretional_broadcastAB) {
OpTester test("Equal");
test.AddInput<int32_t>("A", {4, 1}, {1, 0, -1, -1});
test.AddInput<int32_t>("B", {2}, {1, 1});
test.AddOutput<bool>("C", {4, 2}, {true, true, false, false, false, false, false, false});
test.Run();
}
TEST(MathOpTest, Equal_multidiretional_broadcastBA) {
OpTester test("Equal");
test.AddInput<int32_t>("A", {2}, {1, 1});
test.AddInput<int32_t>("B", {4, 1}, {1, 0, -1, -1});
test.AddOutput<bool>("C", {4, 2}, {true, true, false, false, false, false, false, false});
test.Run();
}
TEST(MathOpTest, Equal_multidiretional_broadcastAB_bool) {
OpTester test("Equal");
test.AddInput<bool>("A", {4, 1}, {true, false, false, false});
test.AddInput<bool>("B", {2}, {true, true});
test.AddOutput<bool>("C", {4, 2}, {true, true, false, false, false, false, false, false});
test.Run();
}
TEST(MathOpTest, Mean_6) {
OpTester test("Mean", 6);
std::vector<int64_t> dims{3, 3};