diff --git a/caffe2/operators/spatial_batch_norm_op_cudnn.cu b/caffe2/operators/spatial_batch_norm_op_cudnn.cu index 702a9a629d6..4b629caa8cc 100644 --- a/caffe2/operators/spatial_batch_norm_op_cudnn.cu +++ b/caffe2/operators/spatial_batch_norm_op_cudnn.cu @@ -191,6 +191,66 @@ class CuDNNSpatialBNOp final : public SpatialBNOp { return true; } const double alpha = static_cast(1.0f - momentum_); + +#if CUDNN_VERSION_MIN(8, 0, 0) + // Currently not supporting CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION + auto op = CUDNN_BATCHNORM_OPS_BN; + + // Calculate the workspace size + size_t workspace_size; + CUDNN_ENFORCE(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( + cudnn_wrapper_.inline_cudnn_handle(), + mode_, + op, + data_desc_, + NULL, + data_desc_, + param_desc_, + NULL, + &workspace_size)); + + // Calculate the reserved space size - common function for forward and backward + size_t reserve_size; + CUDNN_ENFORCE(cudnnGetBatchNormalizationTrainingExReserveSpaceSize( + cudnn_wrapper_.inline_cudnn_handle(), + mode_, + op, + NULL, + data_desc_, + &reserve_size)); + + // CUDNN state is needed to access the workspace + size_t cudnn_state_(OperatorBase::GetSingleArgument("cudnn_state", 0)); + cudnn_wrapper_.with_cudnn_state( + cudnn_state_, [&](CuDNNState* state) { + CUDNN_ENFORCE(cudnnBatchNormalizationForwardTrainingEx( + cudnn_wrapper_.inline_cudnn_handle(), + mode_, + CUDNN_BATCHNORM_OPS_BN, + cudnnTypeWrapper::kOne(), + cudnnTypeWrapper::kZero(), + data_desc_, + X_data, + NULL, + NULL, + data_desc_, + Y_data, + param_desc_, + scale_data, + bias_data, + alpha, + running_mean_data, + running_var_data, + epsilon_, + saved_mean_data, + saved_inv_std_data, + NULL, + state->workspace().get(workspace_size), + workspace_size, + state->workspace().get(reserve_size), + reserve_size)); + }); +#else CUDNN_ENFORCE(cudnnBatchNormalizationForwardTraining( cudnn_wrapper_.inline_cudnn_handle(), mode_, @@ -209,6 +269,7 @@ class CuDNNSpatialBNOp final : public SpatialBNOp { epsilon_, saved_mean_data, saved_inv_std_data)); +#endif // CUDNN_VERSION_MIN(8, 0, 0) } return true; } @@ -314,6 +375,71 @@ class CuDNNSpatialBNGradientOp final : public SpatialBNGradientOp { data_desc_, param_desc_); } +#if CUDNN_VERSION_MIN(8, 0, 0) + // Currently not supporting CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION + auto op = CUDNN_BATCHNORM_OPS_BN; + + size_t workspace_size; + CUDNN_ENFORCE(cudnnGetBatchNormalizationBackwardExWorkspaceSize( + cudnn_wrapper_.inline_cudnn_handle(), + mode_, + op, + data_desc_, + NULL, + data_desc_, + NULL, + data_desc_, + param_desc_, + NULL, + &workspace_size)); + + // Calculate the reserved space size - common function for forward and backward + size_t reserve_size; + CUDNN_ENFORCE(cudnnGetBatchNormalizationTrainingExReserveSpaceSize( + cudnn_wrapper_.inline_cudnn_handle(), + mode_, + op, + NULL, + data_desc_, + &reserve_size)); + + // CUDNN state is needed to access the workspace + size_t cudnn_state_(OperatorBase::GetSingleArgument("cudnn_state", 0)); + cudnn_wrapper_.with_cudnn_state( + cudnn_state_, [&](CuDNNState* state) { + CUDNN_ENFORCE(cudnnBatchNormalizationBackwardEx( + cudnn_wrapper_.inline_cudnn_handle(), + mode_, + op, + cudnnTypeWrapper::kOne(), + cudnnTypeWrapper::kZero(), + cudnnTypeWrapper::kOne(), + cudnnTypeWrapper::kZero(), + data_desc_, + X_data, + NULL, + NULL, + data_desc_, + dY_data, + NULL, + NULL, + data_desc_, + dX_data, + param_desc_, + scale_data, + NULL, + dscale_data, + dbias_data, + epsilon_, + saved_mean_data, + saved_rstd_data, + NULL, + state->workspace().get(workspace_size), + workspace_size, + state->workspace().get(reserve_size), + reserve_size)); + }); +#else CUDNN_ENFORCE(cudnnBatchNormalizationBackward( cudnn_wrapper_.inline_cudnn_handle(), mode_, @@ -334,7 +460,7 @@ class CuDNNSpatialBNGradientOp final : public SpatialBNGradientOp { epsilon_, saved_mean_data, saved_rstd_data)); - +#endif // CUDNN_VERSION_MIN(8, 0, 0) return true; }