mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Use cuda memset async (#21216)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
0bbd061a54
commit
3f6b7430d6
1 changed files with 3 additions and 1 deletions
|
|
@ -35,7 +35,8 @@ struct PadAndUnflattenFunctor {
|
|||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
const CudaT* input_data = reinterpret_cast<const CudaT*>(input_tensor.Data<T>());
|
||||
|
||||
CUDA_CALL_THROW(cudaMemset(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT)));
|
||||
CUDA_CALL_THROW(cudaMemsetAsync(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT),
|
||||
stream));
|
||||
PadAndUnflattenImpl<CudaT>(stream, input_element_count, output_element_stride_fdm, index_value_upper_bound,
|
||||
input_data, indices_tensor.Data<int64_t>(),
|
||||
reinterpret_cast<CudaT*>(output_tensor.MutableData<T>()));
|
||||
|
|
@ -48,6 +49,7 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const {
|
|||
const Tensor* input_tensor = context->Input<Tensor>(0);
|
||||
const Tensor* indices_tensor = context->Input<Tensor>(1);
|
||||
const Tensor* unflatten_dims_tensor = context->Input<Tensor>(2); // Parse the 1-D shape tensor.
|
||||
|
||||
ORT_ENFORCE(unflatten_dims_tensor->Shape().NumDimensions() == 1,
|
||||
"unflatten_dims_tensor tensor must be 1-D.", unflatten_dims_tensor->Shape().NumDimensions());
|
||||
ORT_ENFORCE(unflatten_dims_tensor->Shape().Size() == 2,
|
||||
|
|
|
|||
Loading…
Reference in a new issue