enable BN training in cpu inference build (#8269)

This commit is contained in:
Vincent Wang 2021-07-02 04:15:59 +08:00 committed by GitHub
parent 996a98b3ac
commit 9cfe642b34
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 31 deletions

View file

@ -32,25 +32,24 @@ template <typename T>
class BatchNorm : public OpKernel {
public:
explicit BatchNorm(const OpKernelInfo& op_kernel_info)
: OpKernel(op_kernel_info),
is_spatial_(op_kernel_info.GetAttrOrDefault<int64_t>("spatial", 1) == 1) {
auto st = op_kernel_info.GetAttr<float>("epsilon", &epsilon_);
ORT_ENFORCE(st.IsOK(), st.ErrorMessage());
: OpKernel(op_kernel_info), is_spatial_(op_kernel_info.GetAttrOrDefault<int64_t>("spatial", 1) == 1) {
epsilon_ = op_kernel_info.GetAttrOrDefault<float>("epsilon", 1e-5f);
momentum_ = op_kernel_info.GetAttrOrDefault<float>("momentum", 0.9f);
// For opset 6-8, if spatial attribute exists, pick up the value (by default spatial == 1)
// From opset 9 onwards, by default, only the spatial case (spatial == 1) is defined per spec
// For opset 14 onwards, training is an attribute.
// For opset < 14, since no training attribute is present we assume optional outputs indicate training mode.
if (op_kernel_info.node().SinceVersion() == 14) {
is_train_ = op_kernel_info.GetAttrOrDefault<int64_t>("training_mode", 0) == 1;
size_t output_count = op_kernel_info.node().OutputDefs().size();
ORT_ENFORCE((is_train_ && output_count == 3) || (!is_train_ && output_count == 1),
"Output running_mean and running_var are valid and required for training mode.");
} else {
is_train_ = OpKernel::Node().OutputDefs().size() > 1;
}
// For opset 14 onwards, training is true iff we have optional outputs present
// For opset < 14, since no training attribute is present we assume optional outputs indicate training mode
is_train_ = OpKernel::Node().OutputDefs().size() > 1;
#if defined(ENABLE_TRAINING)
ORT_ENFORCE(!is_train_ || is_spatial_, "Training mode does not support non-spatial BN");
auto mt = op_kernel_info.GetAttr<float>("momentum", &momentum_);
ORT_ENFORCE(mt.IsOK(), mt.ErrorMessage());
#else
ORT_ENFORCE(!is_train_, "Training mode is not supported in this build.");
#endif
}
Status Compute(OpKernelContext* p_op_kernel_context) const override {
@ -78,7 +77,6 @@ class BatchNorm : public OpKernel {
// calculate sample_size (including all channels)
size_t sample_size_incl_all_channels = sample_size * C;
#if defined(ENABLE_TRAINING)
AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(p_op_kernel_context->GetTempSpaceAllocator(&alloc));
@ -88,7 +86,8 @@ class BatchNorm : public OpKernel {
Tensor* saved_mean = is_train_ ? p_op_kernel_context->Output(3, mean->Shape()) : nullptr;
Tensor* saved_inv_std = is_train_ ? p_op_kernel_context->Output(4, var->Shape()) : nullptr;
// With opset <= 9, both must be defined in training. If opset >= 14, neither should be defined in training
ORT_ENFORCE(!is_train_ || ((!saved_mean && !saved_inv_std) || (saved_mean && saved_inv_std)), "Invalid number of outputs for BN training");
ORT_ENFORCE(!is_train_ || ((!saved_mean && !saved_inv_std) || (saved_mean && saved_inv_std)),
"Invalid number of outputs for BN training");
Tensor saved_mean_allocated, saved_inv_std_allocated;
if (is_train_ && !saved_mean) {
saved_mean_allocated = Tensor(DataTypeImpl::GetType<T>(), mean->Shape(), alloc);
@ -96,7 +95,6 @@ class BatchNorm : public OpKernel {
saved_mean = &saved_mean_allocated;
saved_inv_std = &saved_inv_std_allocated;
}
#endif
ConstEigenArrayMap<T> X_arr(X->template Data<T>(),
is_spatial_ ? sample_size : sample_size_incl_all_channels,
@ -104,7 +102,6 @@ class BatchNorm : public OpKernel {
ConstEigenVectorArrayMap<T> scale_arr(scale->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
ConstEigenVectorArrayMap<T> bias_arr(B->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
#if defined(ENABLE_TRAINING)
// Note that we only support spatial BN for training
if (is_train_) {
EigenVectorArrayMap<T> saved_mean_arr(saved_mean->template MutableData<T>(), C);
@ -143,7 +140,6 @@ class BatchNorm : public OpKernel {
running_mean_arr = input_running_mean_arr * momentum_ + saved_mean_arr * (1. - momentum_);
running_var_arr = input_running_var_arr * momentum_ + saved_var_arr * (1. - momentum_);
}
#endif
// Regardless of training or testing, we will apply the estimated mean
// and standard deviation to the input. For testing, they are
@ -155,21 +151,14 @@ class BatchNorm : public OpKernel {
ConstEigenVectorArrayMap<T> var_arr(var->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
inv_std = (var_arr + epsilon_).sqrt().inverse();
} else {
#if defined(ENABLE_TRAINING)
EigenVectorArrayMap<T> saved_inv_std_arr(saved_inv_std->template MutableData<T>(), C);
saved_inv_std_arr = (saved_inv_std_arr + epsilon_).inverse().sqrt();
inv_std = saved_inv_std_arr;
#endif
}
// If we're training, do batch normalization based on computation from this batch
ConstEigenVectorArrayMap<T> mean_arr(
#if defined(ENABLE_TRAINING)
!is_train_ ? mean->template Data<T>() : saved_mean->template Data<T>(),
#else
mean->template Data<T>(),
#endif
is_spatial_ ? C : sample_size_incl_all_channels);
ConstEigenVectorArrayMap<T> mean_arr(!is_train_ ? mean->template Data<T>() : saved_mean->template Data<T>(),
is_spatial_ ? C : sample_size_incl_all_channels);
// We can fuse the output computation as follows:
// ((x - est_mean) * (inv_var) * scale + bias
@ -195,7 +184,7 @@ class BatchNorm : public OpKernel {
protected:
float epsilon_;
float momentum_{0};
float momentum_;
const bool is_spatial_;
int64_t is_train_;
};

View file

@ -736,7 +736,6 @@ TEST(BatchNormTest, BatchNorm2d_fp16) {
#endif
// TODO fix flaky test for CUDA
#if defined(ENABLE_TRAINING)
TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) {
OpTester test("BatchNormalization", 9);
float epsilon = 1e-05f;
@ -790,7 +789,6 @@ TEST(BatchNormTest, ForwardTrainingTestOpset14) {
// exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm()
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
#endif
} // namespace test
} // namespace onnxruntime