mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
enable BN training in cpu inference build (#8269)
This commit is contained in:
parent
996a98b3ac
commit
9cfe642b34
2 changed files with 18 additions and 31 deletions
|
|
@ -32,25 +32,24 @@ template <typename T>
|
|||
class BatchNorm : public OpKernel {
|
||||
public:
|
||||
explicit BatchNorm(const OpKernelInfo& op_kernel_info)
|
||||
: OpKernel(op_kernel_info),
|
||||
is_spatial_(op_kernel_info.GetAttrOrDefault<int64_t>("spatial", 1) == 1) {
|
||||
auto st = op_kernel_info.GetAttr<float>("epsilon", &epsilon_);
|
||||
ORT_ENFORCE(st.IsOK(), st.ErrorMessage());
|
||||
: OpKernel(op_kernel_info), is_spatial_(op_kernel_info.GetAttrOrDefault<int64_t>("spatial", 1) == 1) {
|
||||
epsilon_ = op_kernel_info.GetAttrOrDefault<float>("epsilon", 1e-5f);
|
||||
momentum_ = op_kernel_info.GetAttrOrDefault<float>("momentum", 0.9f);
|
||||
|
||||
// For opset 6-8, if spatial attribute exists, pick up the value (by default spatial == 1)
|
||||
// From opset 9 onwards, by default, only the spatial case (spatial == 1) is defined per spec
|
||||
// For opset 14 onwards, training is an attribute.
|
||||
// For opset < 14, since no training attribute is present we assume optional outputs indicate training mode.
|
||||
if (op_kernel_info.node().SinceVersion() == 14) {
|
||||
is_train_ = op_kernel_info.GetAttrOrDefault<int64_t>("training_mode", 0) == 1;
|
||||
size_t output_count = op_kernel_info.node().OutputDefs().size();
|
||||
ORT_ENFORCE((is_train_ && output_count == 3) || (!is_train_ && output_count == 1),
|
||||
"Output running_mean and running_var are valid and required for training mode.");
|
||||
} else {
|
||||
is_train_ = OpKernel::Node().OutputDefs().size() > 1;
|
||||
}
|
||||
|
||||
// For opset 14 onwards, training is true iff we have optional outputs present
|
||||
// For opset < 14, since no training attribute is present we assume optional outputs indicate training mode
|
||||
is_train_ = OpKernel::Node().OutputDefs().size() > 1;
|
||||
#if defined(ENABLE_TRAINING)
|
||||
ORT_ENFORCE(!is_train_ || is_spatial_, "Training mode does not support non-spatial BN");
|
||||
|
||||
auto mt = op_kernel_info.GetAttr<float>("momentum", &momentum_);
|
||||
ORT_ENFORCE(mt.IsOK(), mt.ErrorMessage());
|
||||
#else
|
||||
ORT_ENFORCE(!is_train_, "Training mode is not supported in this build.");
|
||||
#endif
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* p_op_kernel_context) const override {
|
||||
|
|
@ -78,7 +77,6 @@ class BatchNorm : public OpKernel {
|
|||
// calculate sample_size (including all channels)
|
||||
size_t sample_size_incl_all_channels = sample_size * C;
|
||||
|
||||
#if defined(ENABLE_TRAINING)
|
||||
AllocatorPtr alloc;
|
||||
ORT_RETURN_IF_ERROR(p_op_kernel_context->GetTempSpaceAllocator(&alloc));
|
||||
|
||||
|
|
@ -88,7 +86,8 @@ class BatchNorm : public OpKernel {
|
|||
Tensor* saved_mean = is_train_ ? p_op_kernel_context->Output(3, mean->Shape()) : nullptr;
|
||||
Tensor* saved_inv_std = is_train_ ? p_op_kernel_context->Output(4, var->Shape()) : nullptr;
|
||||
// With opset <= 9, both must be defined in training. If opset >= 14, neither should be defined in training
|
||||
ORT_ENFORCE(!is_train_ || ((!saved_mean && !saved_inv_std) || (saved_mean && saved_inv_std)), "Invalid number of outputs for BN training");
|
||||
ORT_ENFORCE(!is_train_ || ((!saved_mean && !saved_inv_std) || (saved_mean && saved_inv_std)),
|
||||
"Invalid number of outputs for BN training");
|
||||
Tensor saved_mean_allocated, saved_inv_std_allocated;
|
||||
if (is_train_ && !saved_mean) {
|
||||
saved_mean_allocated = Tensor(DataTypeImpl::GetType<T>(), mean->Shape(), alloc);
|
||||
|
|
@ -96,7 +95,6 @@ class BatchNorm : public OpKernel {
|
|||
saved_mean = &saved_mean_allocated;
|
||||
saved_inv_std = &saved_inv_std_allocated;
|
||||
}
|
||||
#endif
|
||||
|
||||
ConstEigenArrayMap<T> X_arr(X->template Data<T>(),
|
||||
is_spatial_ ? sample_size : sample_size_incl_all_channels,
|
||||
|
|
@ -104,7 +102,6 @@ class BatchNorm : public OpKernel {
|
|||
ConstEigenVectorArrayMap<T> scale_arr(scale->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
|
||||
ConstEigenVectorArrayMap<T> bias_arr(B->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
|
||||
|
||||
#if defined(ENABLE_TRAINING)
|
||||
// Note that we only support spatial BN for training
|
||||
if (is_train_) {
|
||||
EigenVectorArrayMap<T> saved_mean_arr(saved_mean->template MutableData<T>(), C);
|
||||
|
|
@ -143,7 +140,6 @@ class BatchNorm : public OpKernel {
|
|||
running_mean_arr = input_running_mean_arr * momentum_ + saved_mean_arr * (1. - momentum_);
|
||||
running_var_arr = input_running_var_arr * momentum_ + saved_var_arr * (1. - momentum_);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Regardless of training or testing, we will apply the estimated mean
|
||||
// and standard deviation to the input. For testing, they are
|
||||
|
|
@ -155,21 +151,14 @@ class BatchNorm : public OpKernel {
|
|||
ConstEigenVectorArrayMap<T> var_arr(var->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
|
||||
inv_std = (var_arr + epsilon_).sqrt().inverse();
|
||||
} else {
|
||||
#if defined(ENABLE_TRAINING)
|
||||
EigenVectorArrayMap<T> saved_inv_std_arr(saved_inv_std->template MutableData<T>(), C);
|
||||
saved_inv_std_arr = (saved_inv_std_arr + epsilon_).inverse().sqrt();
|
||||
inv_std = saved_inv_std_arr;
|
||||
#endif
|
||||
}
|
||||
|
||||
// If we're training, do batch normalization based on computation from this batch
|
||||
ConstEigenVectorArrayMap<T> mean_arr(
|
||||
#if defined(ENABLE_TRAINING)
|
||||
!is_train_ ? mean->template Data<T>() : saved_mean->template Data<T>(),
|
||||
#else
|
||||
mean->template Data<T>(),
|
||||
#endif
|
||||
is_spatial_ ? C : sample_size_incl_all_channels);
|
||||
ConstEigenVectorArrayMap<T> mean_arr(!is_train_ ? mean->template Data<T>() : saved_mean->template Data<T>(),
|
||||
is_spatial_ ? C : sample_size_incl_all_channels);
|
||||
|
||||
// We can fuse the output computation as follows:
|
||||
// (x - est_mean) * (inv_var) * scale + bias
|
||||
|
|
@ -195,7 +184,7 @@ class BatchNorm : public OpKernel {
|
|||
|
||||
protected:
|
||||
float epsilon_;
|
||||
float momentum_{0};
|
||||
float momentum_;
|
||||
const bool is_spatial_;
|
||||
int64_t is_train_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -736,7 +736,6 @@ TEST(BatchNormTest, BatchNorm2d_fp16) {
|
|||
#endif
|
||||
|
||||
// TODO fix flaky test for CUDA
|
||||
#if defined(ENABLE_TRAINING)
|
||||
TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) {
|
||||
OpTester test("BatchNormalization", 9);
|
||||
float epsilon = 1e-05f;
|
||||
|
|
@ -790,7 +789,6 @@ TEST(BatchNormTest, ForwardTrainingTestOpset14) {
|
|||
// exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm()
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue