Add Col2Im CPU op (#12311)

**Description**
This PR implements N-dimensional Col2Im as a contrib CPU Op as specified
by ONNX's https://github.com/onnx/onnx/pull/3948

**Motivation and Context**
- Col2Im enables models such as:
  - [SS-DCNet](https://github.com/xhp-hust-2018-2011/SS-DCNet)
  - [DSTT](https://github.com/ruiliu-ai/DSTT)
- It also serves to document ORT's obscure `math::Col2imNd` utility

Signed-off-by: Liqun Fu <liqfu@microsoft.com>
Co-authored-by: Liqun Fu <liqfu@microsoft.com>
This commit is contained in:
Thiago Crepaldi 2023-01-25 15:23:00 -05:00 committed by GitHub
parent 94b1791974
commit 32c05fcdd1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 317 additions and 1 deletions

View file

@ -58,6 +58,7 @@ Do not modify directly.*
|||12|**T** = tensor(double), tensor(float), tensor(int64), tensor(int8), tensor(uint64), tensor(uint8)|
|||11|**T** = tensor(float)|
|||[6, 10]|**T** = tensor(float)|
|Col2Im|*in* input:**T**<br> *in* image_shape:**tensor(int64)**<br> *in* block_shape:**tensor(int64)**<br> *out* output:**T**|18+|**T** = tensor(float)|
|Compress|*in* input:**T**<br> *in* condition:**T1**<br> *out* output:**T**|11+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
|||[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
|Concat|*in* inputs:**T**<br> *out* concat_result:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|

View file

@ -830,6 +830,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceSumSquare);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Col2Im);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseAnd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseAnd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseAnd);
@ -2163,6 +2164,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
ReduceSumSquare)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double,
ReduceSumSquare)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, Col2Im)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, BitwiseAnd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int16_t, BitwiseAnd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, BitwiseAnd)>,

View file

@ -0,0 +1,113 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cpu/tensor/col2im.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
namespace onnxruntime {
// Registers the CPU kernel for Col2Im at opset 18, bound to the Col2Im<float> implementation.
// The type constraint "T" is limited to float tensors because
// math::Col2im and math::Col2imNd only support float data type
ONNX_CPU_OPERATOR_KERNEL(
Col2Im,
18,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Col2Im<float>);
// Rearranges column blocks back into a batched N-dimensional image.
//
// Inputs read from `context`:
//   0: column tensor, indexed here as [N, C * prod(block_shape), L]
//   1: image_shape  - int64 tensor holding the spatial dimensions of the output image
//   2: block_shape  - int64 tensor holding the block (kernel) dimensions
// Output 0 is allocated with shape [N, C, image_shape...].
//
// When the 'dilations' / 'strides' attributes are absent they default to 1 per image dimension,
// and an absent 'pads' attribute defaults to 0 for every begin/end entry.
// Malformed attributes or inputs are rejected through ORT_ENFORCE.
template <typename T>
Status Col2Im<T>::Compute(OpKernelContext* context) const {
  const auto* col_tensor = context->Input<Tensor>(0);
  const auto* image_shape = context->Input<Tensor>(1);
  const auto* kernel_shape = context->Input<Tensor>(2);

  const size_t image_dim_number = onnxruntime::narrow<size_t>(image_shape->Shape().Size());

  // block_shape is indexed in lock-step with image_shape below, so a shorter tensor would be
  // read out of bounds.
  ORT_ENFORCE(onnxruntime::narrow<size_t>(kernel_shape->Shape().Size()) == image_dim_number,
              "size of 'block_shape' input should equal to the number of image dimensions.");

  TensorShapeVector dilations;
  if (dilations_.empty()) {
    dilations.resize(image_dim_number, 1);
  } else {
    ORT_ENFORCE(dilations_.size() == image_dim_number, "size of 'dilations' attribute, if provided, should equal to the number of image dimmensions.");
    dilations = dilations_;
  }

  TensorShapeVector pads;
  if (pads_.empty()) {
    pads.resize(image_dim_number * 2, 0);
  } else {
    ORT_ENFORCE(pads_.size() == 2 * image_dim_number, "size of 'pads' attribute, if provided, should equal to twice the number of image dimmensions.");
    pads = pads_;
  }

  TensorShapeVector strides;
  if (strides_.empty()) {
    strides.resize(image_dim_number, 1);
  } else {
    ORT_ENFORCE(strides_.size() == image_dim_number, "size of 'strides' attribute, if provided, should equal to the number of image dimmensions.");
    strides = strides_;
  }

  int64_t image_shape_size = 1;
  int64_t kernel_shape_size = 1;
  TensorShapeVector adjusted_kernel_shape_dims;
  auto image_dims = image_shape->Data<int64_t>();
  auto kernel_dims = kernel_shape->Data<int64_t>();
  for (size_t i = 0; i < image_dim_number; ++i) {
    // A non-positive block dimension would make kernel_shape_size zero or negative and
    // break the channel computation below (division by zero in the worst case).
    ORT_ENFORCE(kernel_dims[i] > 0, "every 'block_shape' dimension should be greater than zero.");
    image_shape_size *= image_dims[i];
    kernel_shape_size *= kernel_dims[i];
    // Extent covered by a block once dilation is applied.
    adjusted_kernel_shape_dims.push_back(dilations[i] * (kernel_dims[i] - 1) + 1);
  }

  TensorShape col_shape = col_tensor->Shape();
  // Dimensions 0 and 1 are indexed directly, so the rank has to be verified first.
  ORT_ENFORCE(col_shape.NumDimensions() == 3,
              "'input' should be a 3-dimensional tensor of shape [N, C * product(block_shape), L].");
  ORT_ENFORCE(col_shape[1] % kernel_shape_size == 0,
              "dimension 1 of 'input' should be a multiple of the product of 'block_shape'.");

  const auto N = col_shape[0];
  const int64_t C = col_shape[1] / kernel_shape_size;
  // Number of output elements per batch entry.
  const int64_t col_stride = C * image_shape_size;
  TensorShape adjusted_kernel_shape(adjusted_kernel_shape_dims);
  // Number of input elements per batch entry.
  const int64_t col_data_stride = col_shape.SizeFromDimension(1);

  TensorShapeVector batched_image_shape_dims, adjusted_image_shape_dims;
  batched_image_shape_dims.insert(batched_image_shape_dims.begin(), {N, C});
  for (size_t i = 0; i < image_dim_number; ++i) {
    batched_image_shape_dims.push_back(image_dims[i]);
    // NOTE(review): this extent is computed without 'pads' or 'strides'; confirm it matches what
    // math::Col2imNd expects when those attributes are not at their default values.
    adjusted_image_shape_dims.push_back(image_dims[i] - adjusted_kernel_shape[i] + 1);
  }
  TensorShape batched_image_shape(batched_image_shape_dims);

  T* image_data = context->Output(0, batched_image_shape)->template MutableData<T>();
  const T* col_data = col_tensor->template Data<T>();

  // The index is int64_t so that it has the same type as the batch size N.
  for (int64_t image_id = 0; image_id < N; ++image_id) {
    if (image_dim_number == 2) {
      // Two spatial dimensions go through the dedicated 2-D routine.
      math::Col2im<T, CPUMathUtil, StorageOrder::NCHW>(
          col_data + image_id * col_data_stride,
          C,
          image_dims[0],
          image_dims[1],
          kernel_dims[0],
          kernel_dims[1],
          dilations[0],
          dilations[1],
          pads[0],
          pads[1],
          pads[2],
          pads[3],
          strides[0],
          strides[1],
          image_data + image_id * col_stride,
          &CPUMathUtil::Instance());
    } else {
      // Any other number of spatial dimensions uses the generic N-d routine.
      math::Col2imNd<T, CPUMathUtil, StorageOrder::NCHW>(
          col_data + image_id * col_data_stride,
          image_dims,
          adjusted_image_shape_dims.data(),
          kernel_shape_size * C,
          image_shape_size * C,
          adjusted_kernel_shape.GetDims().data(),
          strides.data(),
          dilations.data(),
          pads.data(),
          image_dim_number,
          image_data + image_id * col_stride,
          &CPUMathUtil::Instance());
    }
  }

  return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/framework/op_kernel.h"
namespace onnxruntime {
template <typename T>
class Col2Im final : public OpKernel {
public:
explicit Col2Im(const OpKernelInfo& info) : OpKernel(info) {
if (!info.GetAttrs("strides", strides_).IsOK())
ORT_ENFORCE(strides_.empty());
if (!info.GetAttrs("dilations", dilations_).IsOK())
ORT_ENFORCE(dilations_.empty());
if (!info.GetAttrs("pads", pads_).IsOK())
ORT_ENFORCE(pads_.empty());
}
Status Compute(OpKernelContext* context) const override;
private:
TensorShapeVector pads_;
TensorShapeVector dilations_;
TensorShapeVector strides_;
};
} // namespace onnxruntime

View file

@ -702,6 +702,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
{"test_scatternd_add", "Opset 16 not supported yet."},
{"test_scatternd_multiply", "Opset 16 not supported yet."},
{"test_scatter_elements_with_duplicate_indices", "Opset 16 not supported yet."},
{"col2im_pads", "onnx 18 test data error."},
#if defined(DISABLE_OPTIONAL_TYPE)
{"test_optional_get_element", "Optional type not supported in this build flavor."},

View file

@ -0,0 +1,169 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <stdexcept>
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "core/util/math.h"
namespace onnxruntime {
namespace test {
namespace {
// Returns a copy of `input` in which every H x W plane of a serialized [N, C, H, W] buffer
// is transposed: the element read from offset n*C*H*W + c*H*W + (w + W*h) is written to
// offset n*C*H*W + c*H*W + (h + H*w) of the result.
// Elements located past the first N*C*H*W entries (if any) are copied through unchanged.
//
// Throws std::runtime_error("Invalid input") when `input` is empty or holds fewer than
// N*C*H*W elements, since the loops below would otherwise index out of bounds.
template <typename T>
std::vector<T> TransposeSerializedVector(const std::vector<T>& input, size_t N, size_t C, size_t H, size_t W) {
  const size_t plane_size = H * W;
  const size_t image_size = C * plane_size;
  if (input.empty() || input.size() < N * image_size) {
    throw std::runtime_error("Invalid input");
  }

  std::vector<T> trans_vec(input);
  for (size_t n = 0; n < N; ++n) {
    for (size_t c = 0; c < C; ++c) {
      // Offset of the first element of plane (n, c).
      const size_t plane_offset = n * image_size + c * plane_size;
      for (size_t h = 0; h < H; ++h) {
        for (size_t w = 0; w < W; ++w) {
          trans_vec[plane_offset + (h + H * w)] = input[plane_offset + (w + W * h)];
        }
      }
    }
  }

  return trans_vec;
}
} // namespace
// Col2Im on a single one-channel 5x5 image with unit strides, unit dilations and zero padding.
// The expected image holds the values 1..25; the column input is derived from it with
// TransposeSerializedVector.
TEST(Col2ImOpTest, Simple4dNCHW) {
  std::vector<float> expected_image(25);
  std::iota(expected_image.begin(), expected_image.end(), 1.0f);
  std::vector<float> columns = TransposeSerializedVector(expected_image, 1, 1, 5, 5);

  OpTester tester("Col2Im", 18);
  tester.AddAttribute("strides", std::vector<int64_t>{1, 1});
  tester.AddAttribute("dilations", std::vector<int64_t>{1, 1});
  tester.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});

  tester.AddInput<float>("input", {1, 5, 5}, columns);
  tester.AddInput<int64_t>("image_shape", {2}, std::vector<int64_t>{5, 5});
  tester.AddInput<int64_t>("block_shape", {2}, std::vector<int64_t>{1, 5});
  tester.AddOutput<float>("output", {1, 1, 5, 5}, expected_image);

  tester.Run();
}
// Col2Im on a batch of two three-channel 4x5 (non-square) images with unit strides,
// unit dilations and zero padding. The expected output holds the values 1..120; the
// column input is derived from it with TransposeSerializedVector.
TEST(Col2ImOpTest, With2Images3channelsNonSquare4dNCHW) {
  std::vector<float> expected_image(120);
  std::iota(expected_image.begin(), expected_image.end(), 1.0f);
  std::vector<float> columns = TransposeSerializedVector(expected_image, 2, 3, 4, 5);

  OpTester tester("Col2Im", 18);
  tester.AddAttribute("strides", std::vector<int64_t>{1, 1});
  tester.AddAttribute("dilations", std::vector<int64_t>{1, 1});
  tester.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});

  tester.AddInput<float>("input", {2, 15, 4}, columns);
  tester.AddInput<int64_t>("image_shape", {2}, std::vector<int64_t>{4, 5});
  tester.AddInput<int64_t>("block_shape", {2}, std::vector<int64_t>{1, 5});
  tester.AddOutput<float>("output", {2, 3, 4, 5}, expected_image);

  tester.Run();
}
// Col2Im with non-default attributes: strides, dilations and pads are all set to 2.
// The input is declared as {2, 4, 16} and the expected output as {2, 2, 4, 5}, i.e. a batch of
// two two-channel 4x5 images reconstructed from 1x2 blocks. Both tensors are hard-coded below.
TEST(Col2ImOpTest, With2Images2channelsNonSquareDilationPadStride4dNCHW) {
OpTester test("Col2Im", 18);
test.AddAttribute("strides", std::vector<int64_t>{2, 2});
test.AddAttribute("dilations", std::vector<int64_t>{2, 2});
test.AddAttribute("pads", std::vector<int64_t>{2, 2, 2, 2});
// Column input: 2 * 4 * 16 = 128 values, written 16 per source line.
std::vector<float> input{0., 0., 0., 0., 0., 1., 3., 5., 0., 11., 13., 15., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 3., 5., 0., 11., 13., 15., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 21., 23., 25., 0., 31., 33., 35., 0., 0., 0., 0.,
0., 0., 0., 0., 21., 23., 25., 0., 31., 33., 35., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 41., 43., 45., 0., 51., 53., 55., 0., 0., 0., 0.,
0., 0., 0., 0., 41., 43., 45., 0., 51., 53., 55., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 61., 63., 65., 0., 71., 73., 75., 0., 0., 0., 0.,
0., 0., 0., 0., 61., 63., 65., 0., 71., 73., 75., 0., 0., 0., 0., 0.};
// Expected image: 2 * 2 * 4 * 5 = 80 values, written one 5-element image row per source line.
std::vector<float> output{2., 0., 6., 0., 10.,
0., 0., 0., 0., 0.,
22., 0., 26., 0., 30.,
0., 0., 0., 0., 0.,
42., 0., 46., 0., 50.,
0., 0., 0., 0., 0.,
62., 0., 66., 0., 70.,
0., 0., 0., 0., 0.,
82., 0., 86., 0., 90.,
0., 0., 0., 0., 0.,
102., 0., 106., 0., 110.,
0., 0., 0., 0., 0.,
122., 0., 126., 0., 130.,
0., 0., 0., 0., 0.,
142., 0., 146., 0., 150.,
0., 0., 0., 0., 0.};
test.AddInput<float>("input", {2, 4, 16}, input);
test.AddInput<int64_t>("image_shape", {2}, std::vector<int64_t>{4, 5});
test.AddInput<int64_t>("block_shape", {2}, std::vector<int64_t>{1, 2});
test.AddOutput<float>("output", {2, 2, 4, 5}, output);
test.Run();
}
// Col2Im on a single three-channel 5x5 image with unit strides, unit dilations and zero
// padding. The expected output holds the values 1..75; the column input is derived from it
// with TransposeSerializedVector.
TEST(Col2ImOpTest, With3channels4dNCHW) {
  std::vector<float> expected_image(75);
  std::iota(expected_image.begin(), expected_image.end(), 1.0f);
  std::vector<float> columns = TransposeSerializedVector(expected_image, 1, 3, 5, 5);

  OpTester tester("Col2Im", 18);
  tester.AddAttribute("strides", std::vector<int64_t>{1, 1});
  tester.AddAttribute("dilations", std::vector<int64_t>{1, 1});
  tester.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});

  tester.AddInput<float>("input", {1, 15, 5}, columns);
  tester.AddInput<int64_t>("image_shape", {2}, std::vector<int64_t>{5, 5});
  tester.AddInput<int64_t>("block_shape", {2}, std::vector<int64_t>{1, 5});
  tester.AddOutput<float>("output", {1, 3, 5, 5}, expected_image);

  tester.Run();
}
// Col2Im on a batch of two three-channel 5x5 images with unit strides, unit dilations and
// zero padding. The expected output holds the values 1..150; the column input is derived
// from it with TransposeSerializedVector.
TEST(Col2ImOpTest, With2Images3channels4dNCHW) {
  std::vector<float> expected_image(150);
  std::iota(expected_image.begin(), expected_image.end(), 1.0f);
  std::vector<float> columns = TransposeSerializedVector(expected_image, 2, 3, 5, 5);

  OpTester tester("Col2Im", 18);
  tester.AddAttribute("strides", std::vector<int64_t>{1, 1});
  tester.AddAttribute("dilations", std::vector<int64_t>{1, 1});
  tester.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});

  tester.AddInput<float>("input", {2, 15, 5}, columns);
  tester.AddInput<int64_t>("image_shape", {2}, std::vector<int64_t>{5, 5});
  tester.AddInput<int64_t>("block_shape", {2}, std::vector<int64_t>{1, 5});
  tester.AddOutput<float>("output", {2, 3, 5, 5}, expected_image);

  tester.Run();
}
// Col2Im with three spatial dimensions: image_shape {1, 5, 5} and block_shape {1, 1, 5},
// using unit strides, unit dilations and zero padding. The expected output holds the values
// 1..25; the column input is derived from it with TransposeSerializedVector.
TEST(Col2ImOpTest, Simple5dNCHWD) {
  std::vector<float> expected_image(25);
  std::iota(expected_image.begin(), expected_image.end(), 1.0f);
  std::vector<float> columns = TransposeSerializedVector(expected_image, 1, 1, 5, 5);

  OpTester tester("Col2Im", 18);
  tester.AddAttribute("strides", std::vector<int64_t>{1, 1, 1});
  tester.AddAttribute("dilations", std::vector<int64_t>{1, 1, 1});
  tester.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0, 0, 0});

  tester.AddInput<float>("input", {1, 5, 5}, columns);
  tester.AddInput<int64_t>("image_shape", {3}, std::vector<int64_t>{1, 5, 5});
  tester.AddInput<int64_t>("block_shape", {3}, std::vector<int64_t>{1, 1, 5});
  tester.AddOutput<float>("output", {1, 1, 1, 5, 5}, expected_image);

  tester.Run();
}
} // namespace test
} // namespace onnxruntime

View file

@ -102,6 +102,7 @@
"^test_if_opt",
"^test_loop16_seq_none",
"^test_identity_opt",
"^test_col2im_pads*", // remove this when using ONNX with this: https://github.com/onnx/onnx/pull/4769
// Following tests are for opset 16 ops and are not yet implemented in ORT
"^test_roialign_aligned_*",
//GPU failures
@ -118,7 +119,6 @@
"^test_roialign_aligned_*",
"^test_clip_default_int8_max_expanded_cpu",
"^test_clip_default_int8_min_expanded_cpu",
"^test_col2im_*",
"^test_softplus_example_expanded_cpu",
"^test_softplus_expanded_cpu",
"^test_split_*",