From 75934af8967f219eb7b67d32fda45e5ae035a782 Mon Sep 17 00:00:00 2001 From: Ke Zhang Date: Thu, 3 Jan 2019 19:24:06 -0800 Subject: [PATCH] have Im2ColNd support all types and allow customized padding value. (#273) * have Im2ColNd support all types and allow customized padding value. * only specialize the template in order NCHW. * fix build break. * fix build break --- onnxruntime/core/providers/cpu/nn/conv.cc | 10 +- onnxruntime/core/providers/cpu/nn/conv_impl.h | 10 +- onnxruntime/core/util/math.h | 108 +++++++++++++++--- onnxruntime/core/util/math_cpu.cc | 94 ++------------- 4 files changed, 114 insertions(+), 108 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/conv.cc b/onnxruntime/core/providers/cpu/nn/conv.cc index d633aa8fe1..0b437e0a96 100644 --- a/onnxruntime/core/providers/cpu/nn/conv.cc +++ b/onnxruntime/core/providers/cpu/nn/conv.cc @@ -20,15 +20,15 @@ Status Conv::Compute(OpKernelContext* context) const { if (kernel_shape.size() + 2 != W->Shape().NumDimensions()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape num_dims is not compatible with W num_dims.", - " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(), - " W: ", W->Shape().ToString().c_str()); + " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(), + " W: ", W->Shape().ToString().c_str()); } for (size_t i = 0; i < kernel_shape.size(); ++i) { if (kernel_shape[i] != W->Shape()[i + 2]) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.", - " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(), - " W: ", W->Shape().ToString().c_str()); + " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(), + " W: ", W->Shape().ToString().c_str()); } } @@ -111,7 +111,7 @@ Status Conv::Compute(OpKernelContext* context) const { for (int image_id = 0; image_id < N; ++image_id) { for (int group_id = 0; group_id < group_; ++group_id) { - math::Im2colNd( + math::Im2colNd()( Xdata + group_id * X_offset, 
image_shape.GetDims().data(), col_buffer_shape.data(), diff --git a/onnxruntime/core/providers/cpu/nn/conv_impl.h b/onnxruntime/core/providers/cpu/nn/conv_impl.h index c22e9a6e7d..36c9074bf6 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_impl.h +++ b/onnxruntime/core/providers/cpu/nn/conv_impl.h @@ -57,15 +57,15 @@ Status Conv::Compute(OpKernelContext* context) const { if (kernel_shape.size() + 2 != W->Shape().NumDimensions()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape num_dims is not compatible with W num_dims.", - " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(), - " W: ", W->Shape().ToString().c_str()); + " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(), + " W: ", W->Shape().ToString().c_str()); } for (size_t i = 0; i < kernel_shape.size(); ++i) { if (kernel_shape[i] != W->Shape()[i + 2]) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.", - " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(), - " W: ", W->Shape().ToString().c_str()); + " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(), + " W: ", W->Shape().ToString().c_str()); } } @@ -135,7 +135,7 @@ Status Conv::Compute(OpKernelContext* context) const { col_buffer_data, &CPUMathUtil::Instance()); } else { - math::Im2colNd( + math::Im2colNd()( Xdata + group_id * X_offset, image_shape.GetDims().data(), col_buffer_shape.data(), diff --git a/onnxruntime/core/util/math.h b/onnxruntime/core/util/math.h index 9e84a41d04..71ea37f39e 100644 --- a/onnxruntime/core/util/math.h +++ b/onnxruntime/core/util/math.h @@ -327,20 +327,100 @@ void Axpby( Provider* provider); template -void Im2colNd( - const T* data_img, - const int64_t* im_shape, - const int64_t* col_shape, - const int64_t img_size, - const int64_t col_size, - const int64_t* kernel_shape, - const int64_t* stride, - const int64_t* dilation, - const int64_t* pad, - const int64_t N, - T* data_col, - Provider* provider, - bool accumulate_output = 
false); +struct Im2colNd { + void operator()( + const T* data_img, + const int64_t* im_shape, + const int64_t* col_shape, + const int64_t img_size, + const int64_t col_size, + const int64_t* kernel_shape, + const int64_t* stride, + const int64_t* dilation, + const int64_t* pad, + const int64_t N, + T* data_col, + Provider* /*provider*/, + bool accumulate_output = false, + T padding_value = 0); +}; + +template +struct Im2colNd { + void operator()( + const T* data_img, + const int64_t* im_shape, + const int64_t* col_shape, + const int64_t /*img_size*/, + const int64_t /*col_size*/, + const int64_t* kernel_shape, + const int64_t* stride, + const int64_t* dilation, + const int64_t* pad, + const int64_t N, + T* data_col, + Provider* /*provider*/, + bool accumulate_output = false, + T padding_value = 0) { + int64_t kernel_size = 1; + for (int64_t i = 0; i < N; ++i) { + kernel_size *= kernel_shape[i]; + } + const int64_t channels_col = col_shape[0]; + std::vector d_offset(N, 0); + std::vector d_iter(N, 0); + for (int64_t c_col = 0; c_col < channels_col; ++c_col) { + // Loop over spatial axes in reverse order to compute a per-axis offset. + int64_t offset = c_col; + for (int64_t d_i = N - 1; d_i >= 0; --d_i) { + if (d_i < N - 1) { + offset /= kernel_shape[d_i + 1]; + } + d_offset[d_i] = offset % kernel_shape[d_i]; + } + for (bool incremented = true; incremented;) { + // Loop over spatial axes in forward order to compute the indices in the + // image and column, and whether the index lies in the padding. 
+ int64_t index_col = c_col; + int64_t index_im = c_col / kernel_size; + bool is_padding = false; + for (int64_t d_i = 0; d_i < N; ++d_i) { + const int64_t d = d_iter[d_i]; + const int64_t d_im = + d * stride[d_i] - pad[d_i] + d_offset[d_i] * dilation[d_i]; + is_padding |= d_im < 0 || d_im >= im_shape[d_i + 1]; + index_col *= col_shape[d_i + 1]; + index_col += d; + index_im *= im_shape[d_i + 1]; + index_im += d_im; + } + if (!accumulate_output) { + if (is_padding) { + data_col[index_col] = padding_value; + } else { + data_col[index_col] = data_img[index_im]; + } + } else if (!is_padding) { // col2im + data_col[index_im] += data_img[index_col]; + } + // Loop over spatial axes in reverse order to choose an index, + // like counting. + incremented = false; + for (int64_t d_i = N - 1; d_i >= 0; --d_i) { + const int64_t d_max = col_shape[d_i + 1]; + ORT_ENFORCE(d_iter[d_i] < d_max); + if (d_iter[d_i] == d_max - 1) { + d_iter[d_i] = 0; + } else { // d_iter[d_i] < d_max - 1 + ++d_iter[d_i]; + incremented = true; + break; + } + } + } // while(incremented) { + } // for (int c = 0; c < channels_col; ++c) { + } +}; template void Col2imNd( diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index fc149c415a..be0ca16b0c 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -475,15 +475,15 @@ void GemmBatched( } } -// MKL will be implmenet as an execution provider -//////////////////////////////////////////////////////////////////////////////// -// MKL VML alternatives. -// Depending on whether we are using MKL, we will delegate the Caffe math -// functions that are VML-related to either the VML call or the Eigen -// implementation. If you are setting the flags (such as AVX) right for your CPU -// architecture, usually Eigen will deliver a throughput as fast as the VML -// functions. 
-//////////////////////////////////////////////////////////////////////////////// + // MKL will be implemented as an execution provider + //////////////////////////////////////////////////////////////////////////////// + // MKL VML alternatives. + // Depending on whether we are using MKL, we will delegate the Caffe math + // functions that are VML-related to either the VML call or the Eigen + // implementation. If you are setting the flags (such as AVX) right for your CPU + // architecture, usually Eigen will deliver a throughput as fast as the VML + // functions. + //////////////////////////////////////////////////////////////////////////////// #define DELEGATE_SIMPLE_UNARY_FUNCTION(T, Funcname, expr) \ template <> \ @@ -859,80 +859,6 @@ void Select( y[i] = x[i * D + idx[i]]; } } -// Ported from caffe 1. -template <> -void Im2colNd( - const float* data_img, - const int64_t* im_shape, - const int64_t* col_shape, - const int64_t /* img_size*/, - const int64_t /* col_size*/, - const int64_t* kernel_shape, - const int64_t* stride, - const int64_t* dilation, - const int64_t* pad, - const int64_t N, - float* data_col, - CPUMathUtil* /* context */, - bool accumulate_output) { - int64_t kernel_size = 1; - for (int64_t i = 0; i < N; ++i) { - kernel_size *= kernel_shape[i]; - } - const int64_t channels_col = col_shape[0]; - std::vector d_offset(N, 0); - std::vector d_iter(N, 0); - for (int64_t c_col = 0; c_col < channels_col; ++c_col) { - // Loop over spatial axes in reverse order to compute a per-axis offset. - int64_t offset = c_col; - for (int64_t d_i = N - 1; d_i >= 0; --d_i) { - if (d_i < N - 1) { - offset /= kernel_shape[d_i + 1]; - } - d_offset[d_i] = offset % kernel_shape[d_i]; - } - for (bool incremented = true; incremented;) { - // Loop over spatial axes in forward order to compute the indices in the - // image and column, and whether the index lies in the padding. 
- int64_t index_col = c_col; - int64_t index_im = c_col / kernel_size; - bool is_padding = false; - for (int64_t d_i = 0; d_i < N; ++d_i) { - const int64_t d = d_iter[d_i]; - const int64_t d_im = - d * stride[d_i] - pad[d_i] + d_offset[d_i] * dilation[d_i]; - is_padding |= d_im < 0 || d_im >= im_shape[d_i + 1]; - index_col *= col_shape[d_i + 1]; - index_col += d; - index_im *= im_shape[d_i + 1]; - index_im += d_im; - } - if (!accumulate_output) { - if (is_padding) { - data_col[index_col] = 0; - } else { - data_col[index_col] = data_img[index_im]; - } - } else if (!is_padding) { // col2im - data_col[index_im] += data_img[index_col]; - } - // Loop over spatial axes in reverse order to choose an index, - // like counting. - incremented = false; - for (int64_t d_i = N - 1; d_i >= 0; --d_i) { - const int64_t d_max = col_shape[d_i + 1]; - ORT_ENFORCE(d_iter[d_i] < d_max); - if (d_iter[d_i] == d_max - 1) { - d_iter[d_i] = 0; - } else { // d_iter[d_i] < d_max - 1 - ++d_iter[d_i]; - incremented = true; - break; - } - } - } // while(incremented) { - } // for (int c = 0; c < channels_col; ++c) { -} template <> void Col2imNd( @@ -949,7 +875,7 @@ void Col2imNd( float* data_img, CPUMathUtil* context) { Set(img_size, 0, data_img, context); - Im2colNd( + Im2colNd()( data_col, img_shape, col_shape,