From 9cfe642b342f5405f9da370ca09292ff381ea7c8 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Fri, 2 Jul 2021 04:15:59 +0800 Subject: [PATCH] enable BN training in cpu inference build (#8269) --- .../core/providers/cpu/nn/batch_norm.h | 47 +++++++------------ .../providers/cpu/nn/batch_norm_op_test.cc | 2 - 2 files changed, 18 insertions(+), 31 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm.h b/onnxruntime/core/providers/cpu/nn/batch_norm.h index 51b3b516cd..07c70049d9 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm.h +++ b/onnxruntime/core/providers/cpu/nn/batch_norm.h @@ -32,25 +32,24 @@ template class BatchNorm : public OpKernel { public: explicit BatchNorm(const OpKernelInfo& op_kernel_info) - : OpKernel(op_kernel_info), - is_spatial_(op_kernel_info.GetAttrOrDefault("spatial", 1) == 1) { - auto st = op_kernel_info.GetAttr("epsilon", &epsilon_); - ORT_ENFORCE(st.IsOK(), st.ErrorMessage()); + : OpKernel(op_kernel_info), is_spatial_(op_kernel_info.GetAttrOrDefault("spatial", 1) == 1) { + epsilon_ = op_kernel_info.GetAttrOrDefault("epsilon", 1e-5f); + momentum_ = op_kernel_info.GetAttrOrDefault("momentum", 0.9f); // For opset 6-8, if spatial attribute exists, pick up the value (by default spatial == 1) // From opset 9 onwards, by default, only the spatial case (spatial == 1) is defined per spec + // For opset 14 onwards, training is an attribute. + // For opset < 14, since no training attribute is present we assume optional outputs indicate training mode. 
+ if (op_kernel_info.node().SinceVersion() >= 14) { + is_train_ = op_kernel_info.GetAttrOrDefault("training_mode", 0) == 1; + size_t output_count = op_kernel_info.node().OutputDefs().size(); + ORT_ENFORCE((is_train_ && output_count == 3) || (!is_train_ && output_count == 1), + "Output running_mean and running_var are valid and required for training mode."); + } else { + is_train_ = OpKernel::Node().OutputDefs().size() > 1; + } - // For opset 14 onwards, training is true iff we have optional outputs present - // For opset < 14, since no training attribute is present we assume optional outputs indicate training mode - is_train_ = OpKernel::Node().OutputDefs().size() > 1; -#if defined(ENABLE_TRAINING) ORT_ENFORCE(!is_train_ || is_spatial_, "Training mode does not support non-spatial BN"); - - auto mt = op_kernel_info.GetAttr("momentum", &momentum_); - ORT_ENFORCE(mt.IsOK(), mt.ErrorMessage()); -#else - ORT_ENFORCE(!is_train_, "Training mode is not supported in this build."); -#endif } Status Compute(OpKernelContext* p_op_kernel_context) const override { @@ -78,7 +77,6 @@ class BatchNorm : public OpKernel { // calculate sample_size (including all channels) size_t sample_size_incl_all_channels = sample_size * C; -#if defined(ENABLE_TRAINING) AllocatorPtr alloc; ORT_RETURN_IF_ERROR(p_op_kernel_context->GetTempSpaceAllocator(&alloc)); @@ -88,7 +86,8 @@ class BatchNorm : public OpKernel { Tensor* saved_mean = is_train_ ? p_op_kernel_context->Output(3, mean->Shape()) : nullptr; Tensor* saved_inv_std = is_train_ ? p_op_kernel_context->Output(4, var->Shape()) : nullptr; // With opset <= 9, both must be defined in training. 
If opset >= 14, neither should be defined in training - ORT_ENFORCE(!is_train_ || ((!saved_mean && !saved_inv_std) || (saved_mean && saved_inv_std)), "Invalid number of outputs for BN training"); + ORT_ENFORCE(!is_train_ || ((!saved_mean && !saved_inv_std) || (saved_mean && saved_inv_std)), + "Invalid number of outputs for BN training"); Tensor saved_mean_allocated, saved_inv_std_allocated; if (is_train_ && !saved_mean) { saved_mean_allocated = Tensor(DataTypeImpl::GetType(), mean->Shape(), alloc); @@ -96,7 +95,6 @@ class BatchNorm : public OpKernel { saved_mean = &saved_mean_allocated; saved_inv_std = &saved_inv_std_allocated; } -#endif ConstEigenArrayMap X_arr(X->template Data(), is_spatial_ ? sample_size : sample_size_incl_all_channels, @@ -104,7 +102,6 @@ class BatchNorm : public OpKernel { ConstEigenVectorArrayMap scale_arr(scale->template Data(), is_spatial_ ? C : sample_size_incl_all_channels); ConstEigenVectorArrayMap bias_arr(B->template Data(), is_spatial_ ? C : sample_size_incl_all_channels); -#if defined(ENABLE_TRAINING) // Note that we only support spatial BN for training if (is_train_) { EigenVectorArrayMap saved_mean_arr(saved_mean->template MutableData(), C); @@ -143,7 +140,6 @@ class BatchNorm : public OpKernel { running_mean_arr = input_running_mean_arr * momentum_ + saved_mean_arr * (1. - momentum_); running_var_arr = input_running_var_arr * momentum_ + saved_var_arr * (1. - momentum_); } -#endif // Regardless of training or testing, we will apply the estimated mean // and standard deviation to the input. For testing, they are @@ -155,21 +151,14 @@ class BatchNorm : public OpKernel { ConstEigenVectorArrayMap var_arr(var->template Data(), is_spatial_ ? 
C : sample_size_incl_all_channels); inv_std = (var_arr + epsilon_).sqrt().inverse(); } else { -#if defined(ENABLE_TRAINING) EigenVectorArrayMap saved_inv_std_arr(saved_inv_std->template MutableData(), C); saved_inv_std_arr = (saved_inv_std_arr + epsilon_).inverse().sqrt(); inv_std = saved_inv_std_arr; -#endif } // If we're training, do batch normalization based on computation from this batch - ConstEigenVectorArrayMap mean_arr( -#if defined(ENABLE_TRAINING) - !is_train_ ? mean->template Data() : saved_mean->template Data(), -#else - mean->template Data(), -#endif - is_spatial_ ? C : sample_size_incl_all_channels); + ConstEigenVectorArrayMap mean_arr(!is_train_ ? mean->template Data() : saved_mean->template Data(), + is_spatial_ ? C : sample_size_incl_all_channels); // We can fuse the output computation as follows: // ((x - est_mean) * (inv_var) * scale + bias @@ -195,7 +184,7 @@ class BatchNorm : public OpKernel { protected: float epsilon_; - float momentum_{0}; + float momentum_; const bool is_spatial_; int64_t is_train_; }; diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc index cfd17ffccb..33b9abe98d 100644 --- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc @@ -736,7 +736,6 @@ TEST(BatchNormTest, BatchNorm2d_fp16) { #endif // TODO fix flaky test for CUDA -#if defined(ENABLE_TRAINING) TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) { OpTester test("BatchNormalization", 9); float epsilon = 1e-05f; @@ -790,7 +789,6 @@ TEST(BatchNormTest, ForwardTrainingTestOpset14) { // exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm() test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } -#endif } // namespace test } // namespace onnxruntime