MaxpoolWithMask (#3831)

This commit is contained in:
Changming Sun 2020-05-05 22:19:35 -07:00 committed by GitHub
parent edaf8a542c
commit f0c9fbc051
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -2,23 +2,211 @@
// Licensed under the MIT License.
/*
 * Highly specialized code, only works for TP3 L1
 */
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/framework/tensor.h"
#include "core/providers/cpu/nn/pool_base.h"
#include "core/platform/threadpool.h"
namespace onnxruntime {
namespace contrib {
// Functor that max-pools 1-D input rows under an int32 mask.
//
// One "channel" index c addresses one row of the input: the row starts at
// X_data + c * x_step and its results are written at Y_data + c * y_step.
// A zero in the mask stops the scan of the current pooling window; whatever
// maximum was found before the zero is what gets stored.
//
// Field order matters: instances are built by aggregate initialization.
template <typename T>
struct MaxpoolWithMask1DTask final {
  const T* X_data;              // input values
  const int32_t* M_data;        // mask values (0 terminates a window scan)
  T* Y_data;                    // output values
  int64_t x_step;               // elements between consecutive channels in X_data
  int64_t y_step;               // elements between consecutive channels in Y_data
  int64_t pooled_height;        // number of outputs produced per channel
  int64_t stride_h;             // window stride
  int64_t height;               // upper bound (exclusive) for input indices in a channel
  int64_t total_mask_channels;  // modulus applied to the input offset to locate the mask row
  const std::vector<int64_t>& kernel_shape;  // kernel_shape[0] is the window length
  const std::vector<int64_t>& pads;          // pads[0] is the leading padding

  // Per-channel cost estimate: one unit per (output position, kernel tap).
  // The same figure is used for all three TensorOpCost fields.
  TensorOpCost Cost() {
    const double per_channel = static_cast<double>(pooled_height * kernel_shape[0]);
    return TensorOpCost{per_channel, per_channel, per_channel};
  }

  // Processes every channel in [begin, end).
  void operator()(std::ptrdiff_t begin, std::ptrdiff_t end) const {
#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (int64_t channel = begin; channel < end; ++channel) {
      (*this)(channel);
    }
  }

  // Processes the single channel c.
  void operator()(std::ptrdiff_t c) const {
    const T* const in = X_data + c * x_step;
    const int32_t* const mask = M_data + (c * x_step) % total_mask_channels;
    T* const out = Y_data + c * y_step;

    for (int64_t ph = 0; ph < pooled_height; ++ph) {
      // Window in input coordinates, clipped to [0, height).
      const int64_t unclipped_begin = ph * stride_h - pads[0];
      const int64_t window_end = std::min(unclipped_begin + kernel_shape[0], height);
      const int64_t window_begin = std::max(unclipped_begin, static_cast<int64_t>(0));

      T best = std::numeric_limits<T>::lowest();
      // window_begin is never negative, so the mask can be read directly;
      // the scan ends at the first masked-out (zero) position.
      for (int64_t h = window_begin; h < window_end && mask[h] != 0; ++h) {
        const T candidate = in[h];
        if (candidate > best) {
          best = candidate;
        }
      }
      out[ph] = best;
    }
  }
};
// Functor that max-pools 2-D input planes under an int32 mask.
//
// One "channel" index c addresses one plane of the input: the plane starts at
// X_data + c * x_step and its results are written at Y_data + c * y_step.
// Planes are indexed row-major (h * width + w).
//
// Field order matters: instances are built by aggregate initialization.
template <typename T>
struct MaxpoolWithMask2DTask final {
  const T* X_data;              // input values
  const int32_t* M_data;        // mask values (0 cuts a row scan short)
  T* Y_data;                    // output values
  int64_t x_step;               // elements between consecutive channels in X_data
  int64_t y_step;               // elements between consecutive channels in Y_data
  int64_t pooled_height;        // output rows per channel
  int64_t pooled_width;         // output columns per channel
  int64_t stride_h;             // window stride along h
  int64_t stride_w;             // window stride along w
  int64_t height;               // upper bound (exclusive) for h
  int64_t width;                // upper bound (exclusive) for w
  int64_t total_mask_channels;  // modulus applied to the input offset to locate the mask plane
  const std::vector<int64_t>& kernel_shape;  // {kernel_h, kernel_w}
  const std::vector<int64_t>& pads;          // pads[0] / pads[1]: leading padding for h / w

  // Per-channel cost estimate: one unit per (output position, kernel tap),
  // i.e. the full extent of the loop nest in operator()(c). The product is
  // formed in double so large shapes cannot overflow int64_t. The same figure
  // is used for all three TensorOpCost fields.
  TensorOpCost Cost() {
    const double loop_count = static_cast<double>(pooled_height) * static_cast<double>(pooled_width) *
                              static_cast<double>(kernel_shape[0]) * static_cast<double>(kernel_shape[1]);
    return TensorOpCost{loop_count, loop_count, loop_count};
  }

  // Processes every channel in [begin, end).
  void operator()(std::ptrdiff_t begin, std::ptrdiff_t end) const {
#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (int64_t c = begin; c < end; ++c) {
      operator()(c);
    }
  }

  // Processes the single channel c.
  void operator()(std::ptrdiff_t c) const {
    const T* x_d = X_data + c * x_step;
    const int32_t* m_d = M_data + (c * x_step) % total_mask_channels;
    T* y_d = Y_data + c * y_step;
    for (int64_t ph = 0; ph < pooled_height; ++ph) {
      // Window rows in input coordinates, clipped to [0, height).
      int64_t hstart = ph * stride_h - pads[0];
      int64_t hend = std::min(hstart + kernel_shape[0], height);
      hstart = std::max(hstart, static_cast<int64_t>(0));
      for (int64_t pw = 0; pw < pooled_width; ++pw) {
        // Window columns in input coordinates, clipped to [0, width).
        int64_t wstart = pw * stride_w - pads[1];
        int64_t wend = std::min(wstart + kernel_shape[1], width);
        wstart = std::max(wstart, static_cast<int64_t>(0));
        const int64_t pool_index = ph * pooled_width + pw;
        T Yh = std::numeric_limits<T>::lowest();
        for (int64_t h = hstart; h < hend; ++h) {
          for (int64_t w = wstart; w < wend; ++w) {
            const int64_t input_index = h * width + w;
            // NOTE(review): the mask at flat index 0 is never consulted here,
            // whereas the 1-D task checks every index - confirm this is intended.
            // NOTE(review): the break leaves only the w loop; later rows of the
            // window are still scanned - confirm this is intended.
            if (input_index > 0 && m_d[input_index] == 0)
              break;  // if mask == 0, break
            if (x_d[input_index] > Yh) {
              Yh = x_d[input_index];
            }
          }
        }
        y_d[pool_index] = Yh;
      }
    }
  }
};
// Functor that max-pools 3-D input volumes under an int32 mask.
//
// One "channel" index c addresses one volume of the input: the volume starts
// at X_data + c * x_step and its results are written at Y_data + c * y_step.
// Volumes are indexed as h * width * depth + w * depth + d.
//
// Field order matters: instances are built by aggregate initialization.
template <typename T>
struct MaxpoolWithMask3DTask final {
  const T* X_data;              // input values
  const int32_t* M_data;        // mask values (0 cuts a depth scan short)
  T* Y_data;                    // output values
  int64_t x_step;               // elements between consecutive channels in X_data
  int64_t y_step;               // elements between consecutive channels in Y_data
  int64_t pooled_height;        // output extent along h
  int64_t pooled_width;         // output extent along w
  int64_t pooled_depth;         // output extent along d
  int64_t stride_h;             // window stride along h
  int64_t stride_w;             // window stride along w
  int64_t stride_d;             // window stride along d
  int64_t height;               // upper bound (exclusive) for h
  int64_t width;                // upper bound (exclusive) for w
  int64_t depth;                // upper bound (exclusive) for d
  int64_t total_mask_channels;  // modulus applied to the input offset to locate the mask volume
  const std::vector<int64_t>& kernel_shape;  // {kernel_h, kernel_w, kernel_d}
  const std::vector<int64_t>& pads;          // pads[0..2]: leading padding for h / w / d

  // Per-channel cost estimate: one unit per (output position, kernel tap),
  // i.e. the full extent of the loop nest in operator()(c). The product is
  // formed in double so large shapes cannot overflow int64_t. The same figure
  // is used for all three TensorOpCost fields.
  TensorOpCost Cost() {
    const double output_count = static_cast<double>(pooled_height) * static_cast<double>(pooled_width) *
                                static_cast<double>(pooled_depth);
    const double kernel_count = static_cast<double>(kernel_shape[0]) * static_cast<double>(kernel_shape[1]) *
                                static_cast<double>(kernel_shape[2]);
    const double loop_count = output_count * kernel_count;
    return TensorOpCost{loop_count, loop_count, loop_count};
  }

  // Processes every channel in [begin, end).
  void operator()(std::ptrdiff_t begin, std::ptrdiff_t end) const {
#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (int64_t c = begin; c < end; ++c) {
      operator()(c);
    }
  }

  // Processes the single channel c.
  void operator()(std::ptrdiff_t c) const {
    const T* x_d = X_data + c * x_step;
    const int32_t* m_d = M_data + (c * x_step) % total_mask_channels;
    T* y_d = Y_data + c * y_step;
    for (int64_t ph = 0; ph < pooled_height; ++ph) {
      // Window extent along h, clipped to [0, height).
      int64_t hstart = ph * stride_h - pads[0];
      int64_t hend = std::min(hstart + kernel_shape[0], height);
      hstart = std::max(hstart, static_cast<int64_t>(0));
      for (int64_t pw = 0; pw < pooled_width; ++pw) {
        // Window extent along w, clipped to [0, width).
        int64_t wstart = pw * stride_w - pads[1];
        int64_t wend = std::min(wstart + kernel_shape[1], width);
        wstart = std::max(wstart, static_cast<int64_t>(0));
        for (int64_t pd = 0; pd < pooled_depth; ++pd) {
          // Window extent along d, clipped to [0, depth).
          int64_t dstart = pd * stride_d - pads[2];
          int64_t dend = std::min(dstart + kernel_shape[2], depth);
          dstart = std::max(dstart, static_cast<int64_t>(0));
          const int64_t pool_index = ph * pooled_width * pooled_depth + pw * pooled_depth + pd;
          T Yh = std::numeric_limits<T>::lowest();
          for (int64_t h = hstart; h < hend; ++h) {
            for (int64_t w = wstart; w < wend; ++w) {
              for (int64_t d = dstart; d < dend; ++d) {
                const int64_t input_index = h * width * depth + w * depth + d;
                // NOTE(review): the mask at flat index 0 is never consulted here,
                // whereas the 1-D task checks every index - confirm this is intended.
                // NOTE(review): the break leaves only the d loop; the remaining
                // (h, w) positions of the window are still scanned - confirm intended.
                if (input_index > 0 && m_d[input_index] == 0)
                  break;  // if mask == 0, break
                if (x_d[input_index] > Yh) {
                  Yh = x_d[input_index];
                }
              }
            }
          }
          y_d[pool_index] = Yh;
        }
      }
    }
  }
};
// Runs `task` over the channel range [0, total_channels).
//
// Without OpenMP the work is handed to ThreadPool::TryParallelFor together
// with the task's own Cost() estimate. With OpenMP the thread pool is not
// used: the task's (begin, end) overload is invoked once for the whole range.
template <typename T>
inline static void RunMaxpoolLoop(concurrency::ThreadPool* tp, std::ptrdiff_t total_channels, T&& task) {
#ifndef _OPENMP
  concurrency::ThreadPool::TryParallelFor(tp, total_channels, task.Cost(), task);
#else
  ORT_UNUSED_PARAMETER(tp);
  task(0, total_channels);
#endif
}
class MaxpoolWithMask : public OpKernel, public PoolBase {
public:
MaxpoolWithMask(const OpKernelInfo& info) : OpKernel(info), PoolBase(info) {}
MaxpoolWithMask(const OpKernelInfo& info) : OpKernel(info), PoolBase(info) {
}
Status Compute(OpKernelContext* context) const override {
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
const Tensor* X = context->Input<Tensor>(0);
const Tensor* M = context->Input<Tensor>(1);
@ -26,8 +214,9 @@ class MaxpoolWithMask : public OpKernel, public PoolBase {
const TensorShape& m_shape = M->Shape();
ORT_RETURN_IF_NOT(x_shape.NumDimensions() >= 3, "Input dimension cannot be less than 3.");
//TODO: fix this checker later
//ONNXRUNTIME_RETURN_IF_NOT((x_shape[2] == m_shape[2]) && (x_shape[3] == m_shape[3]), " Input shape and mask shape mismatch: ", x_shape, " vs ", m_shape);
// TODO: fix this checker later
// ONNXRUNTIME_RETURN_IF_NOT((x_shape[2] == m_shape[2]) && (x_shape[3] == m_shape[3]), " Input shape and mask shape
// mismatch: ", x_shape, " vs ", m_shape);
std::vector<int64_t> pads = pool_attrs_.pads;
std::vector<int64_t> kernel_shape = pool_attrs_.kernel_shape;
@ -54,28 +243,9 @@ class MaxpoolWithMask : public OpKernel, public PoolBase {
int64_t y_step = pooled_height;
const int64_t total_channels = x_shape[0] * channels;
const int64_t total_mask_channels = m_shape[0] * m_shape[1];
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t c = 0; c < total_channels; ++c) {
const float* x_d = X_data + c * x_step;
const int32_t* m_d = M_data + (c * x_step) % total_mask_channels;
float* y_d = Y_data + c * y_step;
for (int64_t ph = 0; ph < pooled_height; ++ph) {
int64_t hstart = ph * stride_h() - pads[0];
int64_t hend = std::min(hstart + kernel_shape[0], height);
hstart = std::max(hstart, static_cast<int64_t>(0));
float Yh = std::numeric_limits<float>::lowest();
for (int64_t h = hstart; h < hend; ++h) {
if (h >= 0 && m_d[h] == 0) break; // if mask == 0, stop
if (x_d[h] > Yh) {
Yh = x_d[h];
}
}
y_d[ph] = Yh;
}
}
RunMaxpoolLoop<MaxpoolWithMask1DTask<float>>(tp, total_channels,
{X_data, M_data, Y_data, x_step, y_step, pooled_height, stride_h(),
height, total_mask_channels, kernel_shape, pads});
break;
}
@ -84,37 +254,10 @@ class MaxpoolWithMask : public OpKernel, public PoolBase {
int64_t y_step = pooled_height * pooled_width;
const int64_t total_channels = x_shape[0] * channels;
const int64_t total_mask_channels = m_shape[0] * m_shape[1];
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t c = 0; c < total_channels; ++c) {
const float* x_d = X_data + c * x_step;
const int32_t* m_d = M_data + (c * x_step) % total_mask_channels;
float* y_d = Y_data + c * y_step;
for (int64_t ph = 0; ph < pooled_height; ++ph) {
int64_t hstart = ph * stride_h() - pads[0];
int64_t hend = std::min(hstart + kernel_shape[0], height);
hstart = std::max(hstart, static_cast<int64_t>(0));
for (int64_t pw = 0; pw < pooled_width; ++pw) {
int64_t wstart = pw * stride_w() - pads[1];
int64_t wend = std::min(wstart + kernel_shape[1], width);
wstart = std::max(wstart, static_cast<int64_t>(0));
const int64_t pool_index = ph * pooled_width + pw;
float Yh = std::numeric_limits<float>::lowest();
for (int64_t h = hstart; h < hend; ++h) {
for (int64_t w = wstart; w < wend; ++w) {
const int64_t input_index = h * width + w;
if (input_index > 0 && m_d[input_index] == 0) break; // if mask == 0, break
if (x_d[input_index] > Yh) {
Yh = x_d[input_index];
}
}
}
y_d[pool_index] = Yh;
}
}
}
RunMaxpoolLoop<MaxpoolWithMask2DTask<float>>(
tp, total_channels,
{X_data, M_data, Y_data, x_step, y_step, pooled_height, pooled_width, stride_h(), stride_w(), height, width,
total_mask_channels, kernel_shape, pads});
break;
}
case 3: {
@ -122,45 +265,10 @@ class MaxpoolWithMask : public OpKernel, public PoolBase {
int64_t y_step = pooled_height * pooled_width * pooled_depth;
const int64_t total_channels = x_shape[0] * channels;
const int64_t total_mask_channels = m_shape[0] * m_shape[1];
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int64_t c = 0; c < total_channels; ++c) {
const float* x_d = X_data + c * x_step;
const int32_t* m_d = M_data + (c * x_step) % total_mask_channels;
float* y_d = Y_data + c * y_step;
for (int64_t ph = 0; ph < pooled_height; ++ph) {
int64_t hstart = ph * stride_h() - pads[0];
int64_t hend = std::min(hstart + kernel_shape[0], height);
hstart = std::max(hstart, static_cast<int64_t>(0));
for (int64_t pw = 0; pw < pooled_width; ++pw) {
int64_t wstart = pw * stride_w() - pads[1];
int64_t wend = std::min(wstart + kernel_shape[1], width);
wstart = std::max(wstart, static_cast<int64_t>(0));
for (int64_t pd = 0; pd < pooled_depth; ++pd) {
int64_t dstart = pd * stride_d() - pads[2];
int64_t dend = std::min(dstart + kernel_shape[2], depth);
dstart = std::max(dstart, static_cast<int64_t>(0));
const int64_t pool_index =
ph * pooled_width * pooled_depth + pw * pooled_depth + pd;
float Yh = std::numeric_limits<float>::lowest();
for (int64_t h = hstart; h < hend; ++h) {
for (int64_t w = wstart; w < wend; ++w) {
for (int64_t d = dstart; d < dend; ++d) {
const int64_t input_index = h * width * depth + w * depth + d;
if (input_index > 0 && m_d[input_index] == 0) break; // if mask == 0, break
if (x_d[input_index] > Yh) {
Yh = x_d[input_index];
}
}
}
}
y_d[pool_index] = Yh;
}
}
}
}
RunMaxpoolLoop<MaxpoolWithMask3DTask<float>>(
tp, total_channels,
{X_data, M_data, Y_data, x_step, y_step, pooled_height, pooled_width, pooled_depth, stride_h(), stride_w(),
stride_d(), height, width, depth, total_mask_channels, kernel_shape, pads});
break;
}
default: