mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
change impl of CUDA Greater kernel to avoid data corruption. (#969)
This commit is contained in:
parent
f73ce305e9
commit
af4a41fd13
1 changed files with 11 additions and 2 deletions
|
|
@ -3,6 +3,8 @@
|
|||
|
||||
#include "binary_elementwise_ops.h"
|
||||
#include "binary_elementwise_ops_impl.h"
|
||||
#include "unary_elementwise_ops_impl.h"
|
||||
|
||||
using namespace onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
|
@ -313,10 +315,13 @@ Status Greater<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
const Tensor* input1 = context->Input<Tensor>(1);
|
||||
TensorShape output_shape;
|
||||
ORT_RETURN_IF_ERROR(ComputeOutputShape(name, input0->Shape(), input1->Shape(), output_shape));
|
||||
size_t output_size = output_shape.Size();
|
||||
Tensor* output_tensor = context->Output(0, output_shape);
|
||||
|
||||
BinaryElementwisePreparation prepare(this);
|
||||
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(0, input0, input1, output_tensor, &prepare));
|
||||
|
||||
IAllocatorUniquePtr<T> output_buffer = GetScratchBuffer<T>(output_size);
|
||||
Impl_Greater<CudaT>(
|
||||
prepare.output_rank_or_simple_broadcast,
|
||||
prepare.lhs_padded_strides.GpuPtr(),
|
||||
|
|
@ -326,9 +331,13 @@ Status Greater<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
prepare.fdm_output_strides.GpuPtr(),
|
||||
prepare.fdm_H,
|
||||
prepare.fdm_C,
|
||||
reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<bool>()),
|
||||
prepare.output_tensor->Shape().Size());
|
||||
reinterpret_cast<CudaT*>(output_buffer.get()),
|
||||
output_size);
|
||||
|
||||
Impl_Cast<CudaT, ToCudaType<bool>::MappedType>(
|
||||
reinterpret_cast<CudaT*>(output_buffer.get()),
|
||||
reinterpret_cast<ToCudaType<bool>::MappedType*>(output_tensor->template MutableData<bool>()),
|
||||
output_size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue