Cuda instance_norm fix (#9826)

* Fix allocation size & initial values
This commit is contained in:
Ryan Hill 2021-11-22 22:59:20 -08:00 committed by GitHub
parent 24f3d72b77
commit 6749e9fd44
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -100,10 +100,19 @@ Status InstanceNorm<T>::ComputeInternal(OpKernelContext* p_op_kernel_context) co
CudnnTensor stats_desc;
ORT_RETURN_IF_ERROR(stats_desc.Set(std::array<int64_t, 4>{1, stats_count, 1, 1}, CudnnTensor::GetDataType<CudaT>()));
const size_t stats_byte_count = stats_count * sizeof(CudaT);
// Mean & Variance are inputs & outputs and must be initialized to zero to work properly
auto mean = GetScratchBuffer<CudaT>(stats_count);
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mean.get(), 0, stats_byte_count, Stream()));
auto variance = GetScratchBuffer<CudaT>(stats_count);
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(variance.get(), 0, stats_byte_count, Stream()));
// We must set the scale & bias inputs to zero as they are inputs to the calculation
auto unused_scale = GetScratchBuffer<CudaT>(stats_count);
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(unused_scale.get(), 0, stats_byte_count, Stream()));
auto unused_bias = GetScratchBuffer<CudaT>(stats_count);
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(unused_bias.get(), 0, stats_byte_count, Stream()));
// first, compute mean and variance per-instance per-channel using cudnnBatchNorm training
CUDNN_RETURN_IF_ERROR(cudnnBatchNormalizationForwardTraining(