From af4a41fd13012c4d2ac60a46b16161f23afd9fbb Mon Sep 17 00:00:00 2001 From: Du Li Date: Sun, 5 May 2019 20:27:43 -0700 Subject: [PATCH] change impl of CUDA Greater kernel to avoid data corruption. (#969) --- .../providers/cuda/math/binary_elementwise_ops.cc | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc index 8ab48ef15a..ac08044b14 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc @@ -3,6 +3,8 @@ #include "binary_elementwise_ops.h" #include "binary_elementwise_ops_impl.h" +#include "unary_elementwise_ops_impl.h" + using namespace onnxruntime::common; namespace onnxruntime { namespace cuda { @@ -313,10 +315,13 @@ Status Greater::ComputeInternal(OpKernelContext* context) const { const Tensor* input1 = context->Input(1); TensorShape output_shape; ORT_RETURN_IF_ERROR(ComputeOutputShape(name, input0->Shape(), input1->Shape(), output_shape)); + size_t output_size = output_shape.Size(); Tensor* output_tensor = context->Output(0, output_shape); BinaryElementwisePreparation prepare(this); ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(0, input0, input1, output_tensor, &prepare)); + + IAllocatorUniquePtr output_buffer = GetScratchBuffer(output_size); Impl_Greater( prepare.output_rank_or_simple_broadcast, prepare.lhs_padded_strides.GpuPtr(), @@ -326,9 +331,13 @@ Status Greater::ComputeInternal(OpKernelContext* context) const { prepare.fdm_output_strides.GpuPtr(), prepare.fdm_H, prepare.fdm_C, - reinterpret_cast(prepare.output_tensor->template MutableData()), - prepare.output_tensor->Shape().Size()); + reinterpret_cast(output_buffer.get()), + output_size); + Impl_Cast::MappedType>( + reinterpret_cast(output_buffer.get()), + reinterpret_cast::MappedType*>(output_tensor->template MutableData()), + output_size); return Status::OK(); }