From 1e4080061bb6310fa24d7d02e2223a19b3dc186a Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 30 Jan 2020 13:54:38 -0800 Subject: [PATCH] Added support for double in batch norm (#2941) --- .../providers/cpu/cpu_execution_provider.cc | 12 +-- .../core/providers/cpu/nn/batch_norm.cc | 88 ++----------------- .../core/providers/cpu/nn/batch_norm.h | 68 +++++++++++++- .../providers/cpu/nn/batch_norm_op_test.cc | 54 ++++++++++-- 4 files changed, 122 insertions(+), 100 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index f625e82669..ff06f7fdb3 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -114,7 +114,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, Softmax); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, TopK); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, double, BatchNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Conv); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, Flatten); @@ -275,7 +276,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint8_t, Where); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, BatchNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double, MatMul); @@ -578,8 +578,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo()) - .TypeConstraint("scale", DataTypeImpl::GetTensorType()) - .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("mean", DataTypeImpl::GetTensorType()) - .TypeConstraint("var", DataTypeImpl::GetTensorType()), - BatchNorm); -// 'spatial' attribute was removed. -ONNX_CPU_OPERATOR_KERNEL( - BatchNormalization, - 9, - KernelDefBuilder() - .TypeConstraint("X", DataTypeImpl::GetTensorType()) - .TypeConstraint("scale", DataTypeImpl::GetTensorType()) - .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("mean", DataTypeImpl::GetTensorType()) - .TypeConstraint("var", DataTypeImpl::GetTensorType()), - BatchNorm); +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 7, 9, float, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + BatchNorm); -template <> -Status BatchNorm::Compute(OpKernelContext* p_op_kernel_context) const { - const auto* X = p_op_kernel_context->Input(0); - const auto* scale = p_op_kernel_context->Input(1); - const auto* B = p_op_kernel_context->Input(2); - const auto* mean = p_op_kernel_context->Input(3); - const auto* var = p_op_kernel_context->Input(4); +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 7, 9, double, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + BatchNorm); - ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, scale, B, mean, var, is_spatial_)); - - const TensorShape& x_shape = X->Shape(); - Tensor* Y = p_op_kernel_context->Output(0, x_shape); - - const auto& dims_vec = x_shape.GetDims(); - const size_t N = dims_vec[0]; - const size_t C = dims_vec[1]; // assume NCHW as per the spec - - // calculate sample_size (per individual channel) - size_t sample_size = 1; - for (size_t i = 2; i < dims_vec.size(); ++i) { - sample_size *= dims_vec[i]; - } - - // calculate sample_size (including all channels) - size_t sample_size_incl_all_channels = sample_size * C; - - ConstEigenVectorArrayMap scale_arr(scale->template Data(), is_spatial_ ? C : sample_size_incl_all_channels); - ConstEigenVectorArrayMap bias_arr(B->template Data(), is_spatial_ ? C : sample_size_incl_all_channels); - - // Regardless of training or testing, we will apply the estimated mean - // and standard deviation to the input. For testing, they are - // specified directly by the input, and for training, they are computed - // by the op. - Eigen::Array inv_std(is_spatial_ ? C : sample_size_incl_all_channels); - ConstEigenVectorArrayMap var_arr(var->template Data(), is_spatial_ ? C : sample_size_incl_all_channels); - inv_std = (var_arr + epsilon_).sqrt().inverse(); - ConstEigenVectorArrayMap mean_arr(mean->template Data(), is_spatial_ ? C : sample_size_incl_all_channels); - // We can fuse the output computation as follows: - // ((x - est_mean) * (inv_var) * scale + bias - // to - // (x * inv_var * scale) + (bias - est_mean * inv_var * scale) - Eigen::Array new_scale = inv_std * scale_arr; - Eigen::Array new_bias = bias_arr - mean_arr * new_scale; - EigenArrayMap Y_arr(Y->template MutableData(), - is_spatial_ ? sample_size : sample_size_incl_all_channels, - is_spatial_ ? N * C : N); - ConstEigenArrayMap X_arr(X->template Data(), - is_spatial_ ? sample_size : sample_size_incl_all_channels, - is_spatial_ ? N * C : N); - if (is_spatial_) { // spatial == 1 - for (size_t nc = 0; nc < N * C; ++nc) { - Y_arr.col(nc) = X_arr.col(nc) * new_scale(nc % C) + new_bias(nc % C); - } - } else { // spatial == 0 - for (size_t n = 0; n < N; ++n) { - Y_arr.col(n) = X_arr.col(n) * new_scale.col(0) + new_bias.col(0); - } - } - - return Status::OK(); -} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm.h b/onnxruntime/core/providers/cpu/nn/batch_norm.h index 572f2b34ee..35a5a70b3e 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm.h +++ b/onnxruntime/core/providers/cpu/nn/batch_norm.h @@ -23,28 +23,88 @@ #include "core/providers/cpu/nn/autopad_type.h" #include "core/framework/tensor.h" #include "core/util/math_cpuonly.h" +#include "core/providers/cpu/nn/batch_norm_helper.h" namespace onnxruntime { template class BatchNorm : public OpKernel { public: - explicit BatchNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) { + explicit BatchNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info), + is_spatial_(op_kernel_info.GetAttrOrDefault("spatial", 1) == 1) { auto st = op_kernel_info.GetAttr("epsilon", &epsilon_); ORT_ENFORCE(st.IsOK(), st.ErrorMessage()); // For opset 6-8, if spatial attribute exists, pick up the value (by default spatial == 1) // From opset 9 onwards, by default, only the spatial case (spatial == 1) is defined per spec - is_spatial_ = op_kernel_info.GetAttrOrDefault("spatial", 1) == 1 ? true : false; //TODO: momentum } - Status Compute(OpKernelContext* p_op_kernel_context) const override; + Status Compute(OpKernelContext* p_op_kernel_context) const override { + const auto* X = p_op_kernel_context->Input(0); + const auto* scale = p_op_kernel_context->Input(1); + const auto* B = p_op_kernel_context->Input(2); + const auto* mean = p_op_kernel_context->Input(3); + const auto* var = p_op_kernel_context->Input(4); + + ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, scale, B, mean, var, is_spatial_)); + + const TensorShape& x_shape = X->Shape(); + Tensor* Y = p_op_kernel_context->Output(0, x_shape); + + const auto& dims_vec = x_shape.GetDims(); + const size_t N = dims_vec[0]; + const size_t C = dims_vec[1]; // assume NCHW as per the spec + + // calculate sample_size (per individual channel) + size_t sample_size = 1; + for (size_t i = 2; i < dims_vec.size(); ++i) { + sample_size *= dims_vec[i]; + } + + // calculate sample_size (including all channels) + size_t sample_size_incl_all_channels = sample_size * C; + + ConstEigenVectorArrayMap scale_arr(scale->template Data(), is_spatial_ ? C : sample_size_incl_all_channels); + ConstEigenVectorArrayMap bias_arr(B->template Data(), is_spatial_ ? C : sample_size_incl_all_channels); + + // Regardless of training or testing, we will apply the estimated mean + // and standard deviation to the input. For testing, they are + // specified directly by the input, and for training, they are computed + // by the op. + Eigen::Array inv_std(is_spatial_ ? C : sample_size_incl_all_channels); + ConstEigenVectorArrayMap var_arr(var->template Data(), is_spatial_ ? C : sample_size_incl_all_channels); + inv_std = (var_arr + epsilon_).sqrt().inverse(); + ConstEigenVectorArrayMap mean_arr(mean->template Data(), is_spatial_ ? C : sample_size_incl_all_channels); + // We can fuse the output computation as follows: + // ((x - est_mean) * (inv_var) * scale + bias + // to + // (x * inv_var * scale) + (bias - est_mean * inv_var * scale) + Eigen::Array new_scale = inv_std * scale_arr; + Eigen::Array new_bias = bias_arr - mean_arr * new_scale; + EigenArrayMap Y_arr(Y->template MutableData(), + is_spatial_ ? sample_size : sample_size_incl_all_channels, + is_spatial_ ? N * C : N); + ConstEigenArrayMap X_arr(X->template Data(), + is_spatial_ ? sample_size : sample_size_incl_all_channels, + is_spatial_ ? N * C : N); + if (is_spatial_) { // spatial == 1 + for (size_t nc = 0; nc < N * C; ++nc) { + Y_arr.col(nc) = X_arr.col(nc) * new_scale(nc % C) + new_bias(nc % C); + } + } else { // spatial == 0 + for (size_t n = 0; n < N; ++n) { + Y_arr.col(n) = X_arr.col(n) * new_scale.col(0) + new_bias.col(0); + } + } + + return Status::OK(); + } protected: float epsilon_; - bool is_spatial_; + const bool is_spatial_; //int64_t is_test_; ignored in this implementation since we're doing inferencing only. }; } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc index 9b862712bb..0452659687 100644 --- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc @@ -16,10 +16,11 @@ namespace test { using InputDataMap = unordered_map>; using InputShapesMap = unordered_map>; -void TestBatchNorm(const InputDataMap& input_data_map, +template +void TestBatchNorm(const unordered_map>& input_data_map, const InputShapesMap& input_shapes_map, optional epsilon, - const std::initializer_list& expected_output, + const std::initializer_list& expected_output, const vector& expected_output_shape, int64_t spatial_mode = 1, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, @@ -32,12 +33,12 @@ void TestBatchNorm(const InputDataMap& input_data_map, if (opset_version < 9) { // spatial is only defined for opset-8 and below in the spec test.AddAttribute("spatial", spatial_mode); } - test.AddInput("X", input_shapes_map.at("X"), input_data_map.at("X")); - test.AddInput("scale", input_shapes_map.at("scale"), input_data_map.at("scale")); - test.AddInput("B", input_shapes_map.at("B"), input_data_map.at("B")); - test.AddInput("mean", input_shapes_map.at("mean"), input_data_map.at("mean")); - test.AddInput("var", input_shapes_map.at("var"), input_data_map.at("var")); - test.AddOutput("output", expected_output_shape, expected_output); + test.AddInput("X", input_shapes_map.at("X"), input_data_map.at("X")); + test.AddInput("scale", input_shapes_map.at("scale"), input_data_map.at("scale")); + test.AddInput("B", input_shapes_map.at("B"), input_data_map.at("B")); + test.AddInput("mean", input_shapes_map.at("mean"), input_data_map.at("mean")); + test.AddInput("var", input_shapes_map.at("var"), input_data_map.at("var")); + test.AddOutput("output", expected_output_shape, expected_output); // Weight as input is not supported by TensorRT and spatial == 0 is not supported by Nuphar std::unordered_set excluded_eps = {kTensorrtExecutionProvider}; if (spatial_mode == 0) { @@ -83,6 +84,43 @@ TEST(BatchNormTest, PositiveTestCase) { TestBatchNorm(input_data_map, input_shapes_map, epsilon, expected_output, input_shape); } +TEST(BatchNormTest, PositiveTestCaseDouble) { + // This input was taken from the SpatialBN_1.pb, SpatialBN_1_input.pb and SpatialBN_1_output.pb files. + vector X{0.329876f, -0.287158f, -0.411425f, 0.473621f, 0.18156f, -0.170596f, -0.329516f, -0.170733f, -0.121664f, 0.4372f, + -0.485668f, 0.218049f, -0.360263f, 0.107016f, 0.45358f, 0.325056f, 0.15995f, 0.098852f, -0.283453f, -0.373051f, + 0.257542f, 0.0614853f, -0.0592363f, 0.434488f, -0.0179583f, 0.398374f, -0.451602f, -0.132009f, -0.174468f, + -0.0247169f, 0.418897f, -0.47159f, -0.131925f, 0.470943f, 0.118357f, 0.155664f, 0.370062f, -0.279229f, 0.240311f, + -0.451034f, 0.249178f, -0.294496f, 0.13683f, -0.0806475f, -0.309849f, -0.450604f, -0.28048f, -0.420197f, -0.433369f}; + vector scale{0.589433f}; + vector B{-0.384622f}; + vector mean{-2.45673f}; + vector var{1.37998f}; + + unordered_map> input_data_map; + input_data_map.insert({"X", X}); + input_data_map.insert({"scale", scale}); + input_data_map.insert({"B", B}); + input_data_map.insert({"mean", mean}); + input_data_map.insert({"var", var}); + + InputShapesMap input_shapes_map; + vector input_shape{1, 1, 7, 7, 1}; + input_shapes_map.insert({"X", input_shape}); + input_shapes_map.insert({"scale", {1}}); + input_shapes_map.insert({"B", {1}}); + input_shapes_map.insert({"mean", {1}}); + input_shapes_map.insert({"var", {1}}); + + const std::initializer_list expected_output = {1.01359f, 0.703983f, 0.641631f, 1.08571f, 0.939167f, 0.762469f, 0.682729f, 0.762401f, 0.787021f, + 1.06744f, 0.604378f, 0.957476f, 0.667302f, 0.901764f, 1.07566f, 1.01117f, 0.928324f, 0.897667f, + 0.705842f, 0.660885f, 0.977291f, 0.878918f, 0.818345f, 1.06608f, 0.839057f, 1.04796f, 0.621471f, + 0.781831f, 0.760527f, 0.835665f, 1.05825f, 0.611442f, 0.781873f, 1.08437f, 0.907454f, 0.926173f, + 1.03375f, 0.707961f, 0.968646f, 0.621757f, 0.973095f, 0.700301f, 0.916723f, 0.807602f, 0.692598f, + 0.621972f, 0.707334f, 0.63723f, 0.63062f}; + float epsilon = 1e-05f; + TestBatchNorm(input_data_map, input_shapes_map, epsilon, expected_output, input_shape); +} + TEST(BatchNormTest, PositiveTestCaseDefaultEpsilon) { // This input was taken from the SpatialBN_1.pb, SpatialBN_1_input.pb and SpatialBN_1_output.pb files from an older version of this project vector X{0.329876f, -0.287158f, -0.411425f, 0.473621f, 0.18156f, -0.170596f, -0.329516f, -0.170733f, -0.121664f, 0.4372f,