mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-24 22:17:32 +00:00
parent
24f3d72b77
commit
6749e9fd44
1 changed file with 9 additions and 0 deletions
|
|
@@ -100,10 +100,19 @@ Status InstanceNorm<T>::ComputeInternal(OpKernelContext* p_op_kernel_context) co
|
|||
CudnnTensor stats_desc;
|
||||
ORT_RETURN_IF_ERROR(stats_desc.Set(std::array<int64_t, 4>{1, stats_count, 1, 1}, CudnnTensor::GetDataType<CudaT>()));
|
||||
|
||||
const size_t stats_byte_count = stats_count * sizeof(CudaT);
|
||||
|
||||
// Mean & Variance are inputs & outputs and must be initialized to zero to work properly
|
||||
auto mean = GetScratchBuffer<CudaT>(stats_count);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mean.get(), 0, stats_byte_count, Stream()));
|
||||
auto variance = GetScratchBuffer<CudaT>(stats_count);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(variance.get(), 0, stats_byte_count, Stream()));
|
||||
|
||||
// We must set the scale & bias inputs to zero as they are inputs to the calculation
|
||||
auto unused_scale = GetScratchBuffer<CudaT>(stats_count);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(unused_scale.get(), 0, stats_byte_count, Stream()));
|
||||
auto unused_bias = GetScratchBuffer<CudaT>(stats_count);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(unused_bias.get(), 0, stats_byte_count, Stream()));
|
||||
|
||||
// first, compute mean and variance per-instance per-channel using cudnnBatchNorm training
|
||||
CUDNN_RETURN_IF_ERROR(cudnnBatchNormalizationForwardTraining(
|
||||
|
|
|
|||
Loading…
Reference in a new issue