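Change the implementation of the CUDA Greater kernel to avoid data corruption: compute the comparison into a scratch buffer of type T, then cast the result into the bool output tensor. (#969)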

This commit is contained in:
Du Li 2019-05-05 20:27:43 -07:00 committed by GitHub
parent f73ce305e9
commit af4a41fd13
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -3,6 +3,8 @@
#include "binary_elementwise_ops.h"
#include "binary_elementwise_ops_impl.h"
#include "unary_elementwise_ops_impl.h"
using namespace onnxruntime::common;
namespace onnxruntime {
namespace cuda {
@@ -313,10 +315,13 @@ Status Greater<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* input1 = context->Input<Tensor>(1);
TensorShape output_shape;
ORT_RETURN_IF_ERROR(ComputeOutputShape(name, input0->Shape(), input1->Shape(), output_shape));
size_t output_size = output_shape.Size();
Tensor* output_tensor = context->Output(0, output_shape);
BinaryElementwisePreparation prepare(this);
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(0, input0, input1, output_tensor, &prepare));
IAllocatorUniquePtr<T> output_buffer = GetScratchBuffer<T>(output_size);
Impl_Greater<CudaT>(
prepare.output_rank_or_simple_broadcast,
prepare.lhs_padded_strides.GpuPtr(),
@@ -326,9 +331,13 @@ Status Greater<T>::ComputeInternal(OpKernelContext* context) const {
prepare.fdm_output_strides.GpuPtr(),
prepare.fdm_H,
prepare.fdm_C,
reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<bool>()),
prepare.output_tensor->Shape().Size());
reinterpret_cast<CudaT*>(output_buffer.get()),
output_size);
Impl_Cast<CudaT, ToCudaType<bool>::MappedType>(
reinterpret_cast<CudaT*>(output_buffer.get()),
reinterpret_cast<ToCudaType<bool>::MappedType*>(output_tensor->template MutableData<bool>()),
output_size);
return Status::OK();
}