Added support for double in batch norm (#2941)

This commit is contained in:
Changming Sun 2020-01-30 13:54:38 -08:00 committed by GitHub
parent 51595d6a4a
commit 1e4080061b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 122 additions and 100 deletions

View file

@ -114,7 +114,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
// Forward declarations of CPU execution provider kernel classes. Each class name is produced by
// a *_KERNEL_CLASS_NAME macro from the provider, domain, version number(s), optional element
// type and operator name (the numeric arguments are presumably opset versions - confirm against
// the macro definitions, which are not visible here).
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, Softmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, Softmax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, TopK);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, double, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Conv);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, Flatten);
@ -275,7 +276,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
// Forward declarations of CPU kernel classes whose first version argument is 9; the
// non-VERSIONED macros take a single version number instead of a pair.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Where);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, uint8_t, Where);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double, MatMul);
@ -578,8 +578,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
double, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, TopK)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8,
BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9,
float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9,
double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10,
@ -907,8 +909,6 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
Where)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9,
BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float,

View file

@ -20,89 +20,13 @@
namespace onnxruntime {
// spec: https://github.com/onnx/onnx/blob/master/docs/Operators.md#BatchNormalization
// Untyped CPU registration of BatchNormalization with version arguments 7 and 8 (presumably the
// first and last opset covered - confirm against the macro). Every input name is constrained to
// float tensors individually, and the kernel implementation is BatchNorm<float>.
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
BatchNormalization,
7,
8,
KernelDefBuilder()
.TypeConstraint("X", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("scale", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("B", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("mean", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("var", DataTypeImpl::GetTensorType<float>()),
BatchNorm<float>);
// 'spatial' attribute was removed.
// Untyped CPU registration of BatchNormalization with the single version argument 9. It uses the
// same per-input float constraints and the same BatchNorm<float> implementation as the 7-8
// registration; the kernel falls back to spatial == 1 when the attribute is absent.
ONNX_CPU_OPERATOR_KERNEL(
BatchNormalization,
9,
KernelDefBuilder()
.TypeConstraint("X", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("scale", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("B", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("mean", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("var", DataTypeImpl::GetTensorType<float>()),
BatchNorm<float>);
// Typed CPU registration of BatchNormalization for float with version arguments 7 and 9.
// A single constraint on the type parameter "T" replaces the per-input constraints, and the
// kernel implementation is BatchNorm<float>.
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 7, 9, float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
BatchNorm<float>);
// Explicit specialization of BatchNorm<float>::Compute.
// Reads X, scale, B, mean and var, validates them, allocates output 0 with the shape of X and
// writes Y = X * new_scale + new_bias, where new_scale = scale / sqrt(var + epsilon_) and
// new_bias = B - mean * new_scale.
// Returns through ORT_RETURN_IF_ERROR when BatchNormHelper::ValidateInputs reports an error,
// otherwise Status::OK().
template <>
Status BatchNorm<float>::Compute(OpKernelContext* p_op_kernel_context) const {
// Inputs by position: 0 = X, 1 = scale, 2 = B, 3 = mean, 4 = var.
const auto* X = p_op_kernel_context->Input<Tensor>(0);
const auto* scale = p_op_kernel_context->Input<Tensor>(1);
const auto* B = p_op_kernel_context->Input<Tensor>(2);
const auto* mean = p_op_kernel_context->Input<Tensor>(3);
const auto* var = p_op_kernel_context->Input<Tensor>(4);
// Typed CPU registration of BatchNormalization for double (version arguments 7 and 9),
// constrained on type parameter "T" and implemented by BatchNorm<double>.
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 7, 9, double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()),
BatchNorm<double>);
ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, scale, B, mean, var, is_spatial_));
const TensorShape& x_shape = X->Shape();
// The output has exactly the shape of X.
Tensor* Y = p_op_kernel_context->Output(0, x_shape);
const auto& dims_vec = x_shape.GetDims();
const size_t N = dims_vec[0];
const size_t C = dims_vec[1]; // assume NCHW as per the spec
// calculate sample_size (per individual channel)
size_t sample_size = 1;
for (size_t i = 2; i < dims_vec.size(); ++i) {
sample_size *= dims_vec[i];
}
// calculate sample_size (including all channels)
size_t sample_size_incl_all_channels = sample_size * C;
// scale/B/mean/var are viewed as vectors of C elements in spatial mode, or of
// C * sample_size elements otherwise.
ConstEigenVectorArrayMap<float> scale_arr(scale->template Data<float>(), is_spatial_ ? C : sample_size_incl_all_channels);
ConstEigenVectorArrayMap<float> bias_arr(B->template Data<float>(), is_spatial_ ? C : sample_size_incl_all_channels);
// Regardless of training or testing, we will apply the estimated mean
// and standard deviation to the input. For testing, they are
// specified directly by the input, and for training, they are computed
// by the op.
Eigen::Array<float, Eigen::Dynamic, 1> inv_std(is_spatial_ ? C : sample_size_incl_all_channels);
ConstEigenVectorArrayMap<float> var_arr(var->template Data<float>(), is_spatial_ ? C : sample_size_incl_all_channels);
// inv_std = 1 / sqrt(var + epsilon)
inv_std = (var_arr + epsilon_).sqrt().inverse();
ConstEigenVectorArrayMap<float> mean_arr(mean->template Data<float>(), is_spatial_ ? C : sample_size_incl_all_channels);
// We can fuse the output computation as follows:
// (x - est_mean) * inv_std * scale + bias
// to
// (x * inv_std * scale) + (bias - est_mean * inv_std * scale)
Eigen::Array<float, Eigen::Dynamic, 1> new_scale = inv_std * scale_arr;
Eigen::Array<float, Eigen::Dynamic, 1> new_bias = bias_arr - mean_arr * new_scale;
// X and Y are viewed as 2-D arrays (assuming the map takes data, rows, cols - TODO confirm):
// spatial mode uses sample_size rows and N * C columns, otherwise C * sample_size rows and N columns.
EigenArrayMap<float> Y_arr(Y->template MutableData<float>(),
is_spatial_ ? sample_size : sample_size_incl_all_channels,
is_spatial_ ? N * C : N);
ConstEigenArrayMap<float> X_arr(X->template Data<float>(),
is_spatial_ ? sample_size : sample_size_incl_all_channels,
is_spatial_ ? N * C : N);
if (is_spatial_) { // spatial == 1
// Column nc is scaled with the coefficients at index nc % C, i.e. one scalar pair per channel.
for (size_t nc = 0; nc < N * C; ++nc) {
Y_arr.col(nc) = X_arr.col(nc) * new_scale(nc % C) + new_bias(nc % C);
}
} else { // spatial == 0
// Each column n is combined element-wise with the full new_scale / new_bias vectors.
for (size_t n = 0; n < N; ++n) {
Y_arr.col(n) = X_arr.col(n) * new_scale.col(0) + new_bias.col(0);
}
}
return Status::OK();
}
} // namespace onnxruntime

View file

@ -23,28 +23,88 @@
#include "core/providers/cpu/nn/autopad_type.h"
#include "core/framework/tensor.h"
#include "core/util/math_cpuonly.h"
#include "core/providers/cpu/nn/batch_norm_helper.h"
namespace onnxruntime {
// CPU kernel for the BatchNormalization operator, parameterized on the tensor element type T.
// Compute() normalizes input X with the supplied mean/var inputs and applies scale and B:
//   Y = X * new_scale + new_bias, new_scale = scale / sqrt(var + epsilon_), new_bias = B - mean * new_scale.
template <typename T>
class BatchNorm : public OpKernel {
public:
// Reads the kernel attributes: "epsilon" (ORT_ENFORCE checks that the read succeeded) and
// "spatial" (treated as 1 when the attribute is absent).
explicit BatchNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) {
explicit BatchNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info),
is_spatial_(op_kernel_info.GetAttrOrDefault<int64_t>("spatial", 1) == 1) {
auto st = op_kernel_info.GetAttr<float>("epsilon", &epsilon_);
ORT_ENFORCE(st.IsOK(), st.ErrorMessage());
// For opset 6-8, if spatial attribute exists, pick up the value (by default spatial == 1)
// From opset 9 onwards, by default, only the spatial case (spatial == 1) is defined per spec
is_spatial_ = op_kernel_info.GetAttrOrDefault<int64_t>("spatial", 1) == 1 ? true : false;
//TODO: momentum
}
// Validates the five inputs, allocates output 0 with the shape of X and fills it with the
// normalized values. Returns through ORT_RETURN_IF_ERROR when
// BatchNormHelper::ValidateInputs reports an error, otherwise Status::OK().
Status Compute(OpKernelContext* p_op_kernel_context) const override;
Status Compute(OpKernelContext* p_op_kernel_context) const override {
// Inputs by position: 0 = X, 1 = scale, 2 = B, 3 = mean, 4 = var.
const auto* X = p_op_kernel_context->Input<Tensor>(0);
const auto* scale = p_op_kernel_context->Input<Tensor>(1);
const auto* B = p_op_kernel_context->Input<Tensor>(2);
const auto* mean = p_op_kernel_context->Input<Tensor>(3);
const auto* var = p_op_kernel_context->Input<Tensor>(4);
ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, scale, B, mean, var, is_spatial_));
const TensorShape& x_shape = X->Shape();
// The output has exactly the shape of X.
Tensor* Y = p_op_kernel_context->Output(0, x_shape);
const auto& dims_vec = x_shape.GetDims();
const size_t N = dims_vec[0];
const size_t C = dims_vec[1]; // assume NCHW as per the spec
// calculate sample_size (per individual channel)
size_t sample_size = 1;
for (size_t i = 2; i < dims_vec.size(); ++i) {
sample_size *= dims_vec[i];
}
// calculate sample_size (including all channels)
size_t sample_size_incl_all_channels = sample_size * C;
// scale/B/mean/var are viewed as vectors of C elements in spatial mode, or of
// C * sample_size elements otherwise.
ConstEigenVectorArrayMap<T> scale_arr(scale->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
ConstEigenVectorArrayMap<T> bias_arr(B->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
// Regardless of training or testing, we will apply the estimated mean
// and standard deviation to the input. For testing, they are
// specified directly by the input, and for training, they are computed
// by the op.
Eigen::Array<T, Eigen::Dynamic, 1> inv_std(is_spatial_ ? C : sample_size_incl_all_channels);
ConstEigenVectorArrayMap<T> var_arr(var->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
// inv_std = 1 / sqrt(var + epsilon); epsilon_ is a float even when T is double.
inv_std = (var_arr + epsilon_).sqrt().inverse();
ConstEigenVectorArrayMap<T> mean_arr(mean->template Data<T>(), is_spatial_ ? C : sample_size_incl_all_channels);
// We can fuse the output computation as follows:
// (x - est_mean) * inv_std * scale + bias
// to
// (x * inv_std * scale) + (bias - est_mean * inv_std * scale)
Eigen::Array<T, Eigen::Dynamic, 1> new_scale = inv_std * scale_arr;
Eigen::Array<T, Eigen::Dynamic, 1> new_bias = bias_arr - mean_arr * new_scale;
// X and Y are viewed as 2-D arrays (assuming the map takes data, rows, cols - TODO confirm):
// spatial mode uses sample_size rows and N * C columns, otherwise C * sample_size rows and N columns.
EigenArrayMap<T> Y_arr(Y->template MutableData<T>(),
is_spatial_ ? sample_size : sample_size_incl_all_channels,
is_spatial_ ? N * C : N);
ConstEigenArrayMap<T> X_arr(X->template Data<T>(),
is_spatial_ ? sample_size : sample_size_incl_all_channels,
is_spatial_ ? N * C : N);
if (is_spatial_) { // spatial == 1
// Column nc is scaled with the coefficients at index nc % C, i.e. one scalar pair per channel.
for (size_t nc = 0; nc < N * C; ++nc) {
Y_arr.col(nc) = X_arr.col(nc) * new_scale(nc % C) + new_bias(nc % C);
}
} else { // spatial == 0
// Each column n is combined element-wise with the full new_scale / new_bias vectors.
for (size_t n = 0; n < N; ++n) {
Y_arr.col(n) = X_arr.col(n) * new_scale.col(0) + new_bias.col(0);
}
}
return Status::OK();
}
protected:
// Value of the "epsilon" attribute, added to the variance before the square root.
float epsilon_;
// True when the "spatial" attribute equals 1 or is absent; selects per-channel coefficients
// (C entries) instead of per-element coefficients (C * sample_size entries).
bool is_spatial_;
const bool is_spatial_;
//int64_t is_test_; ignored in this implementation since we're doing inferencing only.
};
} // namespace onnxruntime

View file

@ -16,10 +16,11 @@ namespace test {
using InputDataMap = unordered_map<string, vector<float>>;
using InputShapesMap = unordered_map<string, vector<int64_t>>;
void TestBatchNorm(const InputDataMap& input_data_map,
template <typename T>
void TestBatchNorm(const unordered_map<string, vector<T>>& input_data_map,
const InputShapesMap& input_shapes_map,
optional<float> epsilon,
const std::initializer_list<float>& expected_output,
const std::initializer_list<T>& expected_output,
const vector<int64_t>& expected_output_shape,
int64_t spatial_mode = 1,
OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
@ -32,12 +33,12 @@ void TestBatchNorm(const InputDataMap& input_data_map,
if (opset_version < 9) { // spatial is only defined for opset-8 and below in the spec
test.AddAttribute("spatial", spatial_mode);
}
test.AddInput<float>("X", input_shapes_map.at("X"), input_data_map.at("X"));
test.AddInput<float>("scale", input_shapes_map.at("scale"), input_data_map.at("scale"));
test.AddInput<float>("B", input_shapes_map.at("B"), input_data_map.at("B"));
test.AddInput<float>("mean", input_shapes_map.at("mean"), input_data_map.at("mean"));
test.AddInput<float>("var", input_shapes_map.at("var"), input_data_map.at("var"));
test.AddOutput<float>("output", expected_output_shape, expected_output);
test.AddInput<T>("X", input_shapes_map.at("X"), input_data_map.at("X"));
test.AddInput<T>("scale", input_shapes_map.at("scale"), input_data_map.at("scale"));
test.AddInput<T>("B", input_shapes_map.at("B"), input_data_map.at("B"));
test.AddInput<T>("mean", input_shapes_map.at("mean"), input_data_map.at("mean"));
test.AddInput<T>("var", input_shapes_map.at("var"), input_data_map.at("var"));
test.AddOutput<T>("output", expected_output_shape, expected_output);
// Weight as input is not supported by TensorRT and spatial == 0 is not supported by Nuphar
std::unordered_set<std::string> excluded_eps = {kTensorrtExecutionProvider};
if (spatial_mode == 0) {
@ -83,6 +84,43 @@ TEST(BatchNormTest, PositiveTestCase) {
TestBatchNorm(input_data_map, input_shapes_map, epsilon, expected_output, input_shape);
}
TEST(BatchNormTest, PositiveTestCaseDouble) {
  // Double-typed run of the data taken from the SpatialBN_1.pb, SpatialBN_1_input.pb and
  // SpatialBN_1_output.pb files. The literals keep their `f` suffix, so every double below
  // holds the float-rounded value.
  const vector<int64_t> x_dims{1, 1, 7, 7, 1};

  const vector<double> x_data{0.329876f, -0.287158f, -0.411425f, 0.473621f, 0.18156f, -0.170596f, -0.329516f, -0.170733f, -0.121664f, 0.4372f,
                              -0.485668f, 0.218049f, -0.360263f, 0.107016f, 0.45358f, 0.325056f, 0.15995f, 0.098852f, -0.283453f, -0.373051f,
                              0.257542f, 0.0614853f, -0.0592363f, 0.434488f, -0.0179583f, 0.398374f, -0.451602f, -0.132009f, -0.174468f,
                              -0.0247169f, 0.418897f, -0.47159f, -0.131925f, 0.470943f, 0.118357f, 0.155664f, 0.370062f, -0.279229f, 0.240311f,
                              -0.451034f, 0.249178f, -0.294496f, 0.13683f, -0.0806475f, -0.309849f, -0.450604f, -0.28048f, -0.420197f, -0.433369f};

  // Tensor data keyed by input name; the four per-channel inputs hold a single element each.
  const unordered_map<string, vector<double>> inputs{
      {"X", x_data},
      {"scale", {0.589433f}},
      {"B", {-0.384622f}},
      {"mean", {-2.45673f}},
      {"var", {1.37998f}}};

  // Shapes keyed by the same input names.
  const InputShapesMap shapes{
      {"X", x_dims},
      {"scale", {1}},
      {"B", {1}},
      {"mean", {1}},
      {"var", {1}}};

  const std::initializer_list<double> expected = {1.01359f, 0.703983f, 0.641631f, 1.08571f, 0.939167f, 0.762469f, 0.682729f, 0.762401f, 0.787021f,
                                                  1.06744f, 0.604378f, 0.957476f, 0.667302f, 0.901764f, 1.07566f, 1.01117f, 0.928324f, 0.897667f,
                                                  0.705842f, 0.660885f, 0.977291f, 0.878918f, 0.818345f, 1.06608f, 0.839057f, 1.04796f, 0.621471f,
                                                  0.781831f, 0.760527f, 0.835665f, 1.05825f, 0.611442f, 0.781873f, 1.08437f, 0.907454f, 0.926173f,
                                                  1.03375f, 0.707961f, 0.968646f, 0.621757f, 0.973095f, 0.700301f, 0.916723f, 0.807602f, 0.692598f,
                                                  0.621972f, 0.707334f, 0.63723f, 0.63062f};

  const float eps = 1e-05f;
  // The expected output has the same shape as X.
  TestBatchNorm(inputs, shapes, eps, expected, x_dims);
}
TEST(BatchNormTest, PositiveTestCaseDefaultEpsilon) {
// This input was taken from the SpatialBN_1.pb, SpatialBN_1_input.pb and SpatialBN_1_output.pb files from an older version of this project
vector<float> X{0.329876f, -0.287158f, -0.411425f, 0.473621f, 0.18156f, -0.170596f, -0.329516f, -0.170733f, -0.121664f, 0.4372f,