Use cuda memset async (#21216)

### Description
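<!-- Describe your changes. -->
In `PadAndUnflattenFunctor`, replace the `cudaMemset` call that zero-fills the output tensor with `cudaMemsetAsync`, passing the same `stream` that is handed to `PadAndUnflattenImpl`. The zero-fill is now enqueued on that stream, ahead of the `PadAndUnflattenImpl` call that writes into the same output buffer.
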
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
pengwa 2024-07-05 17:27:45 +08:00 committed by GitHub
parent 0bbd061a54
commit 3f6b7430d6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -35,7 +35,8 @@ struct PadAndUnflattenFunctor {
typedef typename ToCudaType<T>::MappedType CudaT;
const CudaT* input_data = reinterpret_cast<const CudaT*>(input_tensor.Data<T>());
CUDA_CALL_THROW(cudaMemset(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT)));
CUDA_CALL_THROW(cudaMemsetAsync(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT),
stream));
PadAndUnflattenImpl<CudaT>(stream, input_element_count, output_element_stride_fdm, index_value_upper_bound,
input_data, indices_tensor.Data<int64_t>(),
reinterpret_cast<CudaT*>(output_tensor.MutableData<T>()));
@@ -48,6 +49,7 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const {
const Tensor* input_tensor = context->Input<Tensor>(0);
const Tensor* indices_tensor = context->Input<Tensor>(1);
const Tensor* unflatten_dims_tensor = context->Input<Tensor>(2); // Parse the 1-D shape tensor.
ORT_ENFORCE(unflatten_dims_tensor->Shape().NumDimensions() == 1,
"unflatten_dims_tensor tensor must be 1-D.", unflatten_dims_tensor->Shape().NumDimensions());
ORT_ENFORCE(unflatten_dims_tensor->Shape().Size() == 2,