diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index f645e638a9..910bdc117d 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -286,8 +286,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int64_t, MatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, double, BatchNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, PRelu); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, float, Upsample); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, int32_t, Upsample); @@ -684,6 +684,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int64_t, Div); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Reshape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Identity); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, float, 
BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, double, BatchNormalization); // !!PLEASE READ BELOW!! Following that, add new entries above this comment @@ -1143,9 +1145,9 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { MatMul)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // OpSet 14 BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm.cc b/onnxruntime/core/providers/cpu/nn/batch_norm.cc index 50d0bc21df..92fff6974b 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm.cc +++ b/onnxruntime/core/providers/cpu/nn/batch_norm.cc @@ -29,11 +29,20 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 7, 8, double, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), BatchNorm); -ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 9, float, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), +// Alias the mean/var inputs (3, 4) to the running_mean/running_var outputs (1, 2) so the running statistics stay preserved across multiple batches +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 9, 13, float, + KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType()), BatchNorm); -ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 9, double, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 9, 13, double, + KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType()), + BatchNorm); + +ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 14, float, + KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType()), + 
BatchNorm); + +ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 14, double, + KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType()), BatchNorm); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm.h b/onnxruntime/core/providers/cpu/nn/batch_norm.h index 46ca31053b..9bf4283ac9 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm.h +++ b/onnxruntime/core/providers/cpu/nn/batch_norm.h @@ -35,11 +35,15 @@ class BatchNorm : public OpKernel { is_spatial_(op_kernel_info.GetAttrOrDefault("spatial", 1) == 1) { auto st = op_kernel_info.GetAttr("epsilon", &epsilon_); ORT_ENFORCE(st.IsOK(), st.ErrorMessage()); - + auto mt = op_kernel_info.GetAttr("momentum", &momentum_); + ORT_ENFORCE(mt.IsOK(), mt.ErrorMessage()); // For opset 6-8, if spatial attribute exists, pick up the value (by default spatial == 1) // From opset 9 onwards, by default, only the spatial case (spatial == 1) is defined per spec - //TODO: momentum + // For opset 14 onwards, training is true iff we have optional outputs present (NOTE(review): the training_mode attribute is not read here - TODO confirm this is intended) + // For opset < 14, since no training attribute is present we assume optional outputs indicate training mode + is_train_ = OpKernel::Node().OutputDefs().size() > 1; + ORT_ENFORCE(!is_train_ || is_spatial_, "Training mode does not support non-spatial BN"); } Status Compute(OpKernelContext* p_op_kernel_context) const override { @@ -67,17 +71,87 @@ class BatchNorm : public OpKernel { // calculate sample_size (including all channels) size_t sample_size_incl_all_channels = sample_size * C; + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(p_op_kernel_context->GetTempSpaceAllocator(&alloc)); + + // Saved mean corresponds to the mean from this batch + // If these optional outputs are present (opset < 14 or internal BN op) we re-use the space for calculations + // Note that with opset < 14 we will be outputting saved_inv_std_dev instead of saved_var + Tensor *saved_mean = is_train_ ? 
p_op_kernel_context->Output(3, mean->Shape()) : nullptr; + Tensor *saved_inv_std = is_train_ ? p_op_kernel_context->Output(4, var->Shape()) : nullptr; + // With opset < 14, both must be defined in training. If opset >= 14, neither should be defined in training + ORT_ENFORCE(!is_train_ || ((!saved_mean && !saved_inv_std) || (saved_mean && saved_inv_std)), "Invalid number of outputs for BN training"); + Tensor saved_mean_allocated, saved_inv_std_allocated; + if (is_train_ && !saved_mean) { + saved_mean_allocated = Tensor(DataTypeImpl::GetType(), mean->Shape(), alloc); + saved_inv_std_allocated = Tensor(DataTypeImpl::GetType(), var->Shape(), alloc); + saved_mean = &saved_mean_allocated; + saved_inv_std = &saved_inv_std_allocated; + } + ConstEigenArrayMap X_arr(X->template Data(), + is_spatial_ ? sample_size : sample_size_incl_all_channels, + is_spatial_ ? N * C : N); ConstEigenVectorArrayMap scale_arr(scale->template Data(), is_spatial_ ? C : sample_size_incl_all_channels); ConstEigenVectorArrayMap bias_arr(B->template Data(), is_spatial_ ? 
C : sample_size_incl_all_channels); + // Note that we only support spatial BN for training + if (is_train_) { + EigenVectorArrayMap saved_mean_arr(saved_mean->template MutableData(), C); + // We first calculate saved_var then later take inverse square root to get saved_inv_std + EigenVectorArrayMap saved_var_arr(saved_inv_std->template MutableData(), C); + saved_mean_arr.setZero(); + saved_var_arr.setZero(); + + for (size_t nc = 0; nc < N * C; ++nc) { + saved_mean_arr(nc % C) += X_arr.col(nc).sum(); + } + + saved_mean_arr /= static_cast(N * sample_size); + for (size_t nc = 0; nc < N * C; ++nc) { + saved_var_arr(nc % C) += (X_arr.col(nc) - saved_mean_arr(nc % C)).matrix().squaredNorm(); + } + saved_var_arr /= static_cast(N * sample_size); + + // The running mean corresponds to the mean from all the batches + // During inference this running mean is used as the mean for BN + auto* running_mean = p_op_kernel_context->Output(1, mean->Shape()); + auto* running_var = p_op_kernel_context->Output(2, var->Shape()); + const auto* input_running_mean = p_op_kernel_context->Input(3); + const auto* input_running_var = p_op_kernel_context->Input(4); + + // Assume that running mean and variance are initialized properly in the model given to us + // Because we alias it, we have the past history here + EigenVectorArrayMap running_mean_arr( + running_mean->template MutableData(), C); + EigenVectorArrayMap running_var_arr( + running_var->template MutableData(), C); + ConstEigenVectorArrayMap input_running_mean_arr( + input_running_mean->template Data(), C); + ConstEigenVectorArrayMap input_running_var_arr( + input_running_var->template Data(), C); + running_mean_arr = input_running_mean_arr * momentum_ + saved_mean_arr * (1. - momentum_); + running_var_arr = input_running_var_arr * momentum_ + saved_var_arr * (1. - momentum_); + } + // Regardless of training or testing, we will apply the estimated mean // and standard deviation to the input. 
For testing, they are // specified directly by the input, and for training, they are computed // by the op. Eigen::Array inv_std(is_spatial_ ? C : sample_size_incl_all_channels); - ConstEigenVectorArrayMap var_arr(var->template Data(), is_spatial_ ? C : sample_size_incl_all_channels); - inv_std = (var_arr + epsilon_).sqrt().inverse(); - ConstEigenVectorArrayMap mean_arr(mean->template Data(), is_spatial_ ? C : sample_size_incl_all_channels); + + if (!is_train_) { + ConstEigenVectorArrayMap var_arr(var->template Data(), is_spatial_ ? C : sample_size_incl_all_channels); + inv_std = (var_arr + epsilon_).sqrt().inverse(); + } else { + EigenVectorArrayMap saved_inv_std_arr(saved_inv_std->template MutableData(), C); + saved_inv_std_arr = (saved_inv_std_arr + epsilon_).inverse().sqrt(); + inv_std = saved_inv_std_arr; + } + + // If we're training, do batch normalization based on computation from this batch + ConstEigenVectorArrayMap mean_arr( + !is_train_ ? mean->template Data() : saved_mean->template Data(), is_spatial_ ? C : sample_size_incl_all_channels); + // We can fuse the output computation as follows: // ((x - est_mean) * (inv_var) * scale + bias // to @@ -87,9 +161,7 @@ class BatchNorm : public OpKernel { EigenArrayMap Y_arr(Y->template MutableData(), is_spatial_ ? sample_size : sample_size_incl_all_channels, is_spatial_ ? N * C : N); - ConstEigenArrayMap X_arr(X->template Data(), - is_spatial_ ? sample_size : sample_size_incl_all_channels, - is_spatial_ ? N * C : N); + if (is_spatial_) { // spatial == 1 for (size_t nc = 0; nc < N * C; ++nc) { Y_arr.col(nc) = X_arr.col(nc) * new_scale(nc % C) + new_bias(nc % C); @@ -99,13 +171,13 @@ class BatchNorm : public OpKernel { Y_arr.col(n) = X_arr.col(n) * new_scale.col(0) + new_bias.col(0); } } - return Status::OK(); } protected: float epsilon_; + float momentum_; const bool is_spatial_; - //int64_t is_test_; ignored in this implementation since we're doing inferencing only. 
+ int64_t is_train_; }; } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc index faf4c604bc..c1113d2786 100644 --- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc @@ -733,16 +733,15 @@ TEST(BatchNormTest, BatchNorm2d_fp16) { test.AddOutput("output", input_shape, f_output); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +#endif -// TODO fix flaky test -TEST(BatchNormTest, DISABLED_ForwardTrainingTest) { - OpTester test("BatchNormalization"); +// TODO fix flaky test for CUDA +TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) { + OpTester test("BatchNormalization", 9); float epsilon = 1e-05f; float momentum = 0.1f; - int64_t spatial = 1; test.AddAttribute("epsilon", epsilon); test.AddAttribute("momentum", momentum); - test.AddAttribute("spatial", spatial); std::vector input_output_dims{2, 2, 2, 2}; std::vector channel_dims{2}; test.AddInput("X", input_output_dims, {-0.2953f, 0.1180f, 1.0973f, -0.1931f, -0.1999f, -0.0237f, 1.5181f, 0.0076f, -1.0830f, -1.5433f, 0.4327f, -0.9813f, 0.7875f, -0.4080f, -2.3144f, 1.5493f}); @@ -750,19 +749,46 @@ TEST(BatchNormTest, DISABLED_ForwardTrainingTest) { test.AddInput("B", channel_dims, {0.0f, 0.0f}); test.AddInput("mean", channel_dims, {1.0f, 2.0f}); test.AddInput("var", channel_dims, {1.0f, 2.0f}); - // values from PyTorch with affine=False, track_running_stats=True flags - test.AddOutput("Y", input_output_dims, {0.0131f, 0.5210f, 1.7244f, 0.1387f, -0.2708f, -0.1191f, 1.2089f, -0.0922f, -0.9548f, -1.5203f, 0.9077f, -0.8298f, 0.5796f, -0.4501f, -2.0921f, 1.2358f}); - // in PyTorch, running_mean and running_var should be initialized to [0.0, 0.0] - test.AddOutput("running_mean", channel_dims, {-0.0306f, 0.0115f}); - test.AddOutput("running_var", channel_dims, {0.0757f, 0.1541f}); - // mean and variance of X 
across channel dimension - test.AddOutput("saved_mean", channel_dims, {-0.306f, 0.115f}); - test.AddOutput("saved_var", channel_dims, {1.229f, 0.861f}); - // exclude CPU Execution Provider so that test is run with CUDA with ForwardTraining mode - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider}); + test.AddOutput("Y", input_output_dims, {0.0131f, 0.5210f, 1.7244f, 0.1387f, -0.2708f, -0.1191f, 1.2089f, -0.0922f, -0.9548f, -1.5203f, 0.9077f, -0.8298f, 0.5796f, -0.4501f, -2.0921f, 1.2358f}); + + test.AddOutput("running_mean", channel_dims, {-0.1754f, 0.303106f}); + test.AddOutput("running_var", channel_dims, {0.696052f, 1.41316f}); + // per-channel mean and inverse standard deviation, 1/sqrt(var + epsilon), of X + // With Opset9 we output saved_inv_std instead of saved_var to match CUDA EP + test.AddOutput("saved_mean", channel_dims, {-0.306f, 0.114562f}); + test.AddOutput("saved_inv_std", channel_dims, {1.2288f, 0.861317f}); + + // exclude CUDA Execution Provider due to flakiness + // exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm() + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(BatchNormTest, ForwardTrainingTestOpset14) { + OpTester test("BatchNormalization", 14); + float epsilon = 1e-05f; + float momentum = 0.1f; + int64_t training_mode = 1; + test.AddAttribute("epsilon", epsilon); + test.AddAttribute("momentum", momentum); + test.AddAttribute("training_mode", training_mode); + std::vector input_output_dims{2, 2, 2, 2}; + std::vector channel_dims{2}; + test.AddInput("X", input_output_dims, {-0.2953f, 0.1180f, 1.0973f, -0.1931f, -0.1999f, -0.0237f, 1.5181f, 0.0076f, -1.0830f, -1.5433f, 0.4327f, -0.9813f, 0.7875f, -0.4080f, -2.3144f, 1.5493f}); + test.AddInput("scale", channel_dims, {1.0f, 1.0f}); + test.AddInput("B", channel_dims, {0.0f, 0.0f}); + test.AddInput("mean", channel_dims, {1.0f, 2.0f}); + test.AddInput("var", channel_dims, {1.0f, 
2.0f}); + + test.AddOutput("Y", input_output_dims, {0.0131f, 0.5210f, 1.7244f, 0.1387f, -0.2708f, -0.1191f, 1.2089f, -0.0922f, -0.9548f, -1.5203f, 0.9077f, -0.8298f, 0.5796f, -0.4501f, -2.0921f, 1.2358f}); + + test.AddOutput("running_mean", channel_dims, {-0.1754f, 0.303106f}); + test.AddOutput("running_var", channel_dims, {0.696052f, 1.41316f}); + + // exclude CUDA Execution Provider due to flakiness + // exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm() + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } -#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/kernel_def_hashes/onnx.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/onnx.cpu.json index c06ab7c7f9..077dc7ca9e 100644 --- a/onnxruntime/test/testdata/kernel_def_hashes/onnx.cpu.json +++ b/onnxruntime/test/testdata/kernel_def_hashes/onnx.cpu.json @@ -247,6 +247,14 @@ "BatchNormalization ai.onnx CPUExecutionProvider", 18128921553709069152 ], + [ + "BatchNormalization ai.onnx CPUExecutionProvider", + 13094179255141648608 + ], + [ + "BatchNormalization ai.onnx CPUExecutionProvider", + 17832136363477464736 + ], [ "BitShift ai.onnx CPUExecutionProvider", 4758677670685660688 diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h index 25967e0c77..e7b7d79d3c 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h @@ -29,6 +29,7 @@ static std::unordered_map> STOP_GRADIENT_EDGES = { {"Not", {0}}, {"And", {0, 1}}, + {"BatchNormalization", {3, 4}}, {"Or", {0, 1}}, {"Xor", {0, 1}}, {"Equal", {0, 1}},