Fix batch norm training op on CPU (#6946)

* Fix batch norm training op on CPU

* Add BatchNorm 14 Op Support

* Update hashes for BN

* Exclude TRT and OpenVINO for BatchNorm training test
This commit is contained in:
Pranav Prakash 2021-05-01 11:25:19 -07:00 committed by GitHub
parent 94c4c44bfc
commit 8ba6ed953f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 158 additions and 34 deletions

View file

@ -286,8 +286,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int64_t, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, float, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, double, BatchNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, PRelu);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, float, Upsample);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, int32_t, Upsample);
@ -684,6 +684,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int64_t, Div);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Reshape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Identity);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, float, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, double, BatchNormalization);
// !!PLEASE READ BELOW!! Following that, add new entries above this comment
@ -1143,9 +1145,9 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int64_t,
MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, float,
BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, double,
BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9,
@ -1758,6 +1760,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float,
Softmax)>,
// OpSet 14
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, float,
CumSum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, double,
@ -1797,6 +1801,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Identity)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, float,
BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, double,
BatchNormalization)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -29,11 +29,20 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 7, 8, double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
BatchNorm<double>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 9, float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
// We alias the running mean to the mean so it stays preserved across multiple batches
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 9, 13, float,
KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
BatchNorm<float>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 9, double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 9, 13, double,
KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
BatchNorm<double>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 14, float,
KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
BatchNorm<float>);
ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 14, double,
KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
BatchNorm<double>);
} // namespace onnxruntime

View file

@ -35,11 +35,15 @@ class BatchNorm : public OpKernel {
is_spatial_(op_kernel_info.GetAttrOrDefault<int64_t>("spatial", 1) == 1) {
auto st = op_kernel_info.GetAttr<float>("epsilon", &epsilon_);
ORT_ENFORCE(st.IsOK(), st.ErrorMessage());
auto mt = op_kernel_info.GetAttr<float>("momentum", &momentum_);
ORT_ENFORCE(mt.IsOK(), mt.ErrorMessage());
// For opset 6-8, if spatial attribute exists, pick up the value (by default spatial == 1)
// From opset 9 onwards, by default, only the spatial case (spatial == 1) is defined per spec
//TODO: momentum
// For opset 14 onwards, training is true iff we have optional outputs present
// For opset < 14, since no training attribute is present we assume optional outputs indicate training mode
is_train_ = OpKernel::Node().OutputDefs().size() > 1;
ORT_ENFORCE(!is_train_ || is_spatial_, "Training mode does not support non-spatial BN");
}
Status Compute(OpKernelContext* p_op_kernel_context) const override {
@ -67,17 +71,87 @@ class BatchNorm : public OpKernel {
// calculate sample_size (including all channels)
size_t sample_size_incl_all_channels = sample_size * C;
AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(p_op_kernel_context->GetTempSpaceAllocator(&alloc));
// Saved mean corresponds to the mean from this batch
// If these optional outputs are present (opset <= 9 or internal BN op) we re-use the space for calculations
// Note that with opset <= 9 we will be outputting saved_inv_std_dev instead of saved_var
Tensor *saved_mean = is_train_ ? p_op_kernel_context->Output(3, mean->Shape()) : nullptr;
Tensor *saved_inv_std = is_train_ ? p_op_kernel_context->Output(4, var->Shape()) : nullptr;
// With opset <= 9, both must be defined in training. If opset >= 14, neither should be defined in training
ORT_ENFORCE(!is_train_ || ((!saved_mean && !saved_inv_std) || (saved_mean && saved_inv_std)), "Invalid number of outputs for BN training");
Tensor saved_mean_allocated, saved_inv_std_allocated;
if (is_train_ && !saved_mean) {
saved_mean_allocated = Tensor(DataTypeImpl::GetType<T>(), mean->Shape(), alloc);
saved_inv_std_allocated = Tensor(DataTypeImpl::GetType<T>(), var->Shape(), alloc);
saved_mean = &saved_mean_allocated;
saved_inv_std = &saved_inv_std_allocated;
}
ConstEigenArrayMap<T> X_arr(X->template Data<T>(),
is_spatial_ ? sample_size : sample_size_incl_all_channels,
is_spatial_ ? N * C : N);
ConstEigenVectorArrayMap<T> scale_arr(scale->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
ConstEigenVectorArrayMap<T> bias_arr(B->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
// Note that we only support spatial BN for training
if (is_train_) {
EigenVectorArrayMap<T> saved_mean_arr(saved_mean->template MutableData<T>(), C);
// We first calculate saved_var then later take inverse square root to get saved_inv_std
EigenVectorArrayMap<T> saved_var_arr(saved_inv_std->template MutableData<T>(), C);
saved_mean_arr.setZero();
saved_var_arr.setZero();
for (size_t nc = 0; nc < N * C; ++nc) {
saved_mean_arr(nc % C) += X_arr.col(nc).sum();
}
saved_mean_arr /= static_cast<T>(N * sample_size);
for (size_t nc = 0; nc < N * C; ++nc) {
saved_var_arr(nc % C) += (X_arr.col(nc) - saved_mean_arr(nc % C)).matrix().squaredNorm();
}
saved_var_arr /= static_cast<T>(N * sample_size);
// The running mean corresponds to the mean from all the batches
// During inference this running mean is used as the mean for BN
auto* running_mean = p_op_kernel_context->Output(1, mean->Shape());
auto* running_var = p_op_kernel_context->Output(2, var->Shape());
const auto* input_running_mean = p_op_kernel_context->Input<Tensor>(3);
const auto* input_running_var = p_op_kernel_context->Input<Tensor>(4);
// Assume that running mean and variance are initialized properly in the model given to us
// Because we alias it, we have the past history here
EigenVectorArrayMap<T> running_mean_arr(
running_mean->template MutableData<T>(), C);
EigenVectorArrayMap<T> running_var_arr(
running_var->template MutableData<T>(), C);
ConstEigenVectorArrayMap<T> input_running_mean_arr(
input_running_mean->template Data<T>(), C);
ConstEigenVectorArrayMap<T> input_running_var_arr(
input_running_var->template Data<T>(), C);
running_mean_arr = input_running_mean_arr * momentum_ + saved_mean_arr * (1. - momentum_);
running_var_arr = input_running_var_arr * momentum_ + saved_var_arr * (1. - momentum_);
}
// Regardless of training or testing, we will apply the estimated mean
// and standard deviation to the input. For testing, they are
// specified directly by the input, and for training, they are computed
// by the op.
Eigen::Array<T, Eigen::Dynamic, 1> inv_std(is_spatial_ ? C : sample_size_incl_all_channels);
ConstEigenVectorArrayMap<T> var_arr(var->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
inv_std = (var_arr + epsilon_).sqrt().inverse();
ConstEigenVectorArrayMap<T> mean_arr(mean->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
if (!is_train_) {
ConstEigenVectorArrayMap<T> var_arr(var->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
inv_std = (var_arr + epsilon_).sqrt().inverse();
} else {
EigenVectorArrayMap<T> saved_inv_std_arr(saved_inv_std->template MutableData<T>(), C);
saved_inv_std_arr = (saved_inv_std_arr + epsilon_).inverse().sqrt();
inv_std = saved_inv_std_arr;
}
// If we're training, do batch normalization based on computation from this batch
ConstEigenVectorArrayMap<T> mean_arr(
!is_train_ ? mean->template Data<T>() : saved_mean->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
// We can fuse the output computation as follows:
// ((x - est_mean) * (inv_var) * scale + bias
// to
@ -87,9 +161,7 @@ class BatchNorm : public OpKernel {
EigenArrayMap<T> Y_arr(Y->template MutableData<T>(),
is_spatial_ ? sample_size : sample_size_incl_all_channels,
is_spatial_ ? N * C : N);
ConstEigenArrayMap<T> X_arr(X->template Data<T>(),
is_spatial_ ? sample_size : sample_size_incl_all_channels,
is_spatial_ ? N * C : N);
if (is_spatial_) { // spatial == 1
for (size_t nc = 0; nc < N * C; ++nc) {
Y_arr.col(nc) = X_arr.col(nc) * new_scale(nc % C) + new_bias(nc % C);
@ -99,13 +171,13 @@ class BatchNorm : public OpKernel {
Y_arr.col(n) = X_arr.col(n) * new_scale.col(0) + new_bias.col(0);
}
}
return Status::OK();
}
protected:
float epsilon_;
float momentum_;
const bool is_spatial_;
//int64_t is_test_; ignored in this implementation since we're doing inferencing only.
int64_t is_train_;
};
} // namespace onnxruntime

View file

@ -733,16 +733,15 @@ TEST(BatchNormTest, BatchNorm2d_fp16) {
test.AddOutput<MLFloat16>("output", input_shape, f_output);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
#endif
// TODO fix flaky test
TEST(BatchNormTest, DISABLED_ForwardTrainingTest) {
OpTester test("BatchNormalization");
// TODO fix flaky test for CUDA
TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) {
OpTester test("BatchNormalization", 9);
float epsilon = 1e-05f;
float momentum = 0.1f;
int64_t spatial = 1;
test.AddAttribute("epsilon", epsilon);
test.AddAttribute("momentum", momentum);
test.AddAttribute("spatial", spatial);
std::vector<int64_t> input_output_dims{2, 2, 2, 2};
std::vector<int64_t> channel_dims{2};
test.AddInput<float>("X", input_output_dims, {-0.2953f, 0.1180f, 1.0973f, -0.1931f, -0.1999f, -0.0237f, 1.5181f, 0.0076f, -1.0830f, -1.5433f, 0.4327f, -0.9813f, 0.7875f, -0.4080f, -2.3144f, 1.5493f});
@ -750,19 +749,46 @@ TEST(BatchNormTest, DISABLED_ForwardTrainingTest) {
test.AddInput<float>("B", channel_dims, {0.0f, 0.0f});
test.AddInput<float>("mean", channel_dims, {1.0f, 2.0f});
test.AddInput<float>("var", channel_dims, {1.0f, 2.0f});
// values from PyTorch with affine=False, track_running_stats=True flags
test.AddOutput<float>("Y", input_output_dims, {0.0131f, 0.5210f, 1.7244f, 0.1387f, -0.2708f, -0.1191f, 1.2089f, -0.0922f, -0.9548f, -1.5203f, 0.9077f, -0.8298f, 0.5796f, -0.4501f, -2.0921f, 1.2358f});
// in PyTorch, running_mean and running_var should be initialized to [0.0, 0.0]
test.AddOutput<float>("running_mean", channel_dims, {-0.0306f, 0.0115f});
test.AddOutput<float>("running_var", channel_dims, {0.0757f, 0.1541f});
// mean and variance of X across channel dimension
test.AddOutput<float>("saved_mean", channel_dims, {-0.306f, 0.115f});
test.AddOutput<float>("saved_var", channel_dims, {1.229f, 0.861f});
// exclude CPU Execution Provider so that test is run with CUDA with ForwardTraining mode
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider});
test.AddOutput<float>("Y", input_output_dims, {0.0131f, 0.5210f, 1.7244f, 0.1387f, -0.2708f, -0.1191f, 1.2089f, -0.0922f, -0.9548f, -1.5203f, 0.9077f, -0.8298f, 0.5796f, -0.4501f, -2.0921f, 1.2358f});
test.AddOutput<float>("running_mean", channel_dims, {-0.1754f, 0.303106f});
test.AddOutput<float>("running_var", channel_dims, {0.696052f, 1.41316f});
// mean and variance of X across channel dimension
// With Opset9 we output saved_inv_std instead of saved_var to match CUDA EP
test.AddOutput<float>("saved_mean", channel_dims, {-0.306f, 0.114562f});
test.AddOutput<float>("saved_inv_std", channel_dims, {1.2288f, 0.861317f});
// exclude CUDA Execution Provider due to flakiness
// exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm()
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
TEST(BatchNormTest, ForwardTrainingTestOpset14) {
OpTester test("BatchNormalization", 14);
float epsilon = 1e-05f;
float momentum = 0.1f;
int64_t training_mode = 1;
test.AddAttribute("epsilon", epsilon);
test.AddAttribute("momentum", momentum);
test.AddAttribute("training_mode", training_mode);
std::vector<int64_t> input_output_dims{2, 2, 2, 2};
std::vector<int64_t> channel_dims{2};
test.AddInput<float>("X", input_output_dims, {-0.2953f, 0.1180f, 1.0973f, -0.1931f, -0.1999f, -0.0237f, 1.5181f, 0.0076f, -1.0830f, -1.5433f, 0.4327f, -0.9813f, 0.7875f, -0.4080f, -2.3144f, 1.5493f});
test.AddInput<float>("scale", channel_dims, {1.0f, 1.0f});
test.AddInput<float>("B", channel_dims, {0.0f, 0.0f});
test.AddInput<float>("mean", channel_dims, {1.0f, 2.0f});
test.AddInput<float>("var", channel_dims, {1.0f, 2.0f});
test.AddOutput<float>("Y", input_output_dims, {0.0131f, 0.5210f, 1.7244f, 0.1387f, -0.2708f, -0.1191f, 1.2089f, -0.0922f, -0.9548f, -1.5203f, 0.9077f, -0.8298f, 0.5796f, -0.4501f, -2.0921f, 1.2358f});
test.AddOutput<float>("running_mean", channel_dims, {-0.1754f, 0.303106f});
test.AddOutput<float>("running_var", channel_dims, {0.696052f, 1.41316f});
// exclude CUDA Execution Provider due to flakiness
// exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm()
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
#endif
} // namespace test
} // namespace onnxruntime

View file

@ -247,6 +247,14 @@
"BatchNormalization ai.onnx CPUExecutionProvider",
18128921553709069152
],
[
"BatchNormalization ai.onnx CPUExecutionProvider",
13094179255141648608
],
[
"BatchNormalization ai.onnx CPUExecutionProvider",
17832136363477464736
],
[
"BitShift ai.onnx CPUExecutionProvider",
4758677670685660688

View file

@ -29,6 +29,7 @@ static std::unordered_map<std::string, std::unordered_set<size_t>>
STOP_GRADIENT_EDGES = {
{"Not", {0}},
{"And", {0, 1}},
{"BatchNormalization", {3, 4}},
{"Or", {0, 1}},
{"Xor", {0, 1}},
{"Equal", {0, 1}},