mkldnn conv 1d, 3d support. (#130)

This commit is contained in:
jywu-msft 2018-12-07 14:48:19 -08:00 committed by GitHub
parent a09a3d3aa5
commit ed5abc8b94
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 102 additions and 77 deletions

View file

@ -13,21 +13,21 @@ namespace onnxruntime {
namespace mkl_dnn {
namespace {
// Struct which encapsulates parameters for MKLDNN Conv2d primitive.
struct Conv2dParams {
mkldnn::memory::dims& src_dims;
mkldnn::memory::dims& filter_dims;
mkldnn::memory::dims& bias_dims;
mkldnn::memory::dims& dst_dims;
mkldnn::memory::dims& strides;
mkldnn::memory::dims& dilations;
mkldnn::memory::dims& padding_left;
mkldnn::memory::dims& padding_right;
// Struct which encapsulates parameters for MKLDNN Conv primitive.
struct ConvParams {
const mkldnn::memory::dims& src_dims;
const mkldnn::memory::dims& filter_dims;
const mkldnn::memory::dims& bias_dims;
const mkldnn::memory::dims& dst_dims;
const mkldnn::memory::dims& strides;
const mkldnn::memory::dims& dilations;
const mkldnn::memory::dims& padding_left;
const mkldnn::memory::dims& padding_right;
Conv2dParams(mkldnn::memory::dims& src_dims, mkldnn::memory::dims& filter_dims,
mkldnn::memory::dims& bias_dims, mkldnn::memory::dims& dst_dims,
mkldnn::memory::dims& strides, mkldnn::memory::dims& dilations,
mkldnn::memory::dims& padding_left, mkldnn::memory::dims& padding_right)
ConvParams(const mkldnn::memory::dims& src_dims, const mkldnn::memory::dims& filter_dims,
const mkldnn::memory::dims& bias_dims, const mkldnn::memory::dims& dst_dims,
const mkldnn::memory::dims& strides, const mkldnn::memory::dims& dilations,
const mkldnn::memory::dims& padding_left, const mkldnn::memory::dims& padding_right)
: src_dims(src_dims),
filter_dims(filter_dims),
bias_dims(bias_dims),
@ -37,11 +37,11 @@ struct Conv2dParams {
padding_left(padding_left),
padding_right(padding_right) {}
// Used as the key for Conv2d Primitive Reuse Pool.
// Used as the key for Conv Primitive Reuse Pool.
std::string ToString() const {
std::string key;
key.reserve(128);
key.append("conv2d_");
key.append("conv_");
AddDimsToKey(key, src_dims);
AddDimsToKey(key, filter_dims);
AddDimsToKey(key, bias_dims);
@ -55,9 +55,9 @@ struct Conv2dParams {
};
template <typename T>
class Conv2dPrimitive : public PrimitiveBase {
class ConvPrimitive : public PrimitiveBase {
public:
explicit Conv2dPrimitive(const Conv2dParams& params)
explicit ConvPrimitive(const ConvParams& params)
: cpu_engine_(GetEngine()) {
context_.stream.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
if (context_.conv_fwd == nullptr) {
@ -65,7 +65,7 @@ class Conv2dPrimitive : public PrimitiveBase {
}
}
~Conv2dPrimitive() = default;
~ConvPrimitive() = default;
void Compute(const T* src_data, const T* filter_data,
const T* dst_data, const T* bias_data = nullptr) {
@ -107,7 +107,7 @@ class Conv2dPrimitive : public PrimitiveBase {
}
private:
struct Conv2dContext {
struct ConvContext {
mkldnn::memory::format src_fmt;
mkldnn::memory::format filter_fmt;
mkldnn::memory::format dst_fmt;
@ -134,7 +134,7 @@ class Conv2dPrimitive : public PrimitiveBase {
std::unique_ptr<mkldnn::stream> stream;
std::vector<mkldnn::primitive> net;
Conv2dContext()
ConvContext()
: src_fmt(mkldnn::memory::format::any),
filter_fmt(mkldnn::memory::format::any),
dst_fmt(mkldnn::memory::format::any),
@ -154,7 +154,7 @@ class Conv2dPrimitive : public PrimitiveBase {
stream(nullptr) {}
};
void Initialize(const Conv2dParams& params) {
void Initialize(const ConvParams& params) {
// Set the memory descriptors to format::any to allow MKLDNN to decide what the optimal memory layout should be
// for the computation given the input params.
context_.src_md.reset(new mkldnn::memory::desc(
@ -222,33 +222,33 @@ class Conv2dPrimitive : public PrimitiveBase {
context_.net.push_back(*context_.conv_fwd);
}
Conv2dContext context_;
ConvContext context_;
mkldnn::engine& cpu_engine_;
};
// Pool which allows for reuse of MKLDNN Conv2d primitives which are expensive to instantiate.
// Pool which allows for reuse of MKLDNN Conv primitives which are expensive to instantiate.
// To address thread safety, the primitives are stored in a map on thread local storage.
template <typename T>
class Conv2dPrimitivePool : public PrimitivePool<T> {
class ConvPrimitivePool : public PrimitivePool<T> {
public:
static Conv2dPrimitive<T>* Get(const Conv2dParams& params) {
Conv2dPrimitive<T>* primitive = dynamic_cast<Conv2dPrimitive<T>*>(
Conv2dPrimitivePool<T>::GetInstance().GetPrimitive(params.ToString()));
static ConvPrimitive<T>* Get(const ConvParams& params) {
ConvPrimitive<T>* primitive = dynamic_cast<ConvPrimitive<T>*>(
ConvPrimitivePool<T>::GetInstance().GetPrimitive(params.ToString()));
if (primitive == nullptr) {
auto conv2d_primitive = std::make_unique<Conv2dPrimitive<T>>(params);
primitive = conv2d_primitive.get();
Conv2dPrimitivePool<T>::GetInstance().SetPrimitive(params.ToString(), std::move(conv2d_primitive));
auto conv_primitive = std::make_unique<ConvPrimitive<T>>(params);
primitive = conv_primitive.get();
ConvPrimitivePool<T>::GetInstance().SetPrimitive(params.ToString(), std::move(conv_primitive));
}
return primitive;
}
private:
Conv2dPrimitivePool() = default;
~Conv2dPrimitivePool() = default;
ConvPrimitivePool() = default;
~ConvPrimitivePool() = default;
static Conv2dPrimitivePool& GetInstance() {
static Conv2dPrimitivePool pool;
static ConvPrimitivePool& GetInstance() {
static ConvPrimitivePool pool;
return pool;
}
};
@ -268,20 +268,20 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
ONNXRUNTIME_RETURN_IF_ERROR(onnxruntime::ConvBase::ValidateInputShape(X, W));
std::vector<int64_t> kernel_shape = onnxruntime::ConvBase::ComputeKernelShape(W->Shape());
const size_t kernel_rank = kernel_shape.size();
// TODO: Support more than 2d kernels
if (kernel_shape.size() != 2) {
if (kernel_rank > 3) {
// Fall Back to CPU implementation.
return onnxruntime::Conv<T>::Compute(context);
}
if (kernel_shape.size() + 2 != W->Shape().NumDimensions()) {
if (kernel_rank + 2 != W->Shape().NumDimensions()) {
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape num_dims is not compatible with W num_dims.",
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
" W: ", W->Shape().ToString().c_str());
}
for (size_t i = 0; i < kernel_shape.size(); ++i) {
for (size_t i = 0; i < kernel_rank; ++i) {
if (kernel_shape[i] != W->Shape()[i + 2]) {
return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.",
" kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
@ -291,15 +291,15 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
std::vector<int64_t> pads(onnxruntime::ConvBase::pads_);
if (pads.empty()) {
pads.resize(kernel_shape.size() * 2, 0);
pads.resize(kernel_rank * 2, 0);
}
std::vector<int64_t> dilations(onnxruntime::ConvBase::dilations_);
if (dilations.empty()) {
dilations.resize(kernel_shape.size(), 1);
dilations.resize(kernel_rank, 1);
}
std::vector<int64_t> strides(onnxruntime::ConvBase::strides_);
if (strides.empty()) {
strides.resize(kernel_shape.size(), 1);
strides.resize(kernel_rank, 1);
}
std::vector<int64_t> Y_dims;
@ -314,21 +314,19 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
if (group_mkl == 1) {
filter_dims_mkl.assign(W->Shape().GetDims().begin(), W->Shape().GetDims().end());
} else {
filter_dims_mkl.assign({
group_mkl,
static_cast<int>(W->Shape()[0] / group_mkl),
static_cast<int>(W->Shape()[1]),
static_cast<int>(W->Shape()[2]),
static_cast<int>(W->Shape()[3]),
});
filter_dims_mkl.assign({group_mkl,
static_cast<int>(W->Shape()[0] / group_mkl)});
filter_dims_mkl.insert(filter_dims_mkl.end(), W->Shape().GetDims().begin() + 1, W->Shape().GetDims().end());
}
mkldnn::memory::dims strides_mkl(strides.begin(), strides.end());
mkldnn::memory::dims dilations_mkl(dilations.begin(), dilations.end());
// mkldnn dilations start from 0 so we need to subtract 1 from each dim.
dilations_mkl[0] -= 1;
dilations_mkl[1] -= 1;
mkldnn::memory::dims padding_left_mkl(pads.begin(), pads.begin() + 2);
mkldnn::memory::dims padding_right_mkl(pads.begin() + 2, pads.end());
for (size_t dim = 0; dim < kernel_rank; dim++) {
dilations_mkl[dim] -= 1;
}
mkldnn::memory::dims padding_left_mkl(pads.begin(), pads.begin() + kernel_rank);
mkldnn::memory::dims padding_right_mkl(pads.begin() + kernel_rank, pads.end());
mkldnn::memory::dims dst_dims_mkl(Y_dims.begin(), Y_dims.end());
mkldnn::memory::dims bias_dims_mkl;
if (B != nullptr) {
@ -350,27 +348,54 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
}
try {
Conv2dParams conv2d_params(src_dims_mkl, filter_dims_mkl, bias_dims_mkl,
dst_dims_mkl, strides_mkl, dilations_mkl,
padding_left_mkl, padding_right_mkl);
Conv2dPrimitive<T>* conv2d_primitive = Conv2dPrimitivePool<T>::Get(conv2d_params);
auto conv_fwd_pd = conv2d_primitive->GetPrimitiveDesc();
ConvParams conv_params(src_dims_mkl, filter_dims_mkl, bias_dims_mkl,
dst_dims_mkl, strides_mkl, dilations_mkl,
padding_left_mkl, padding_right_mkl);
ConvPrimitive<T>* conv_primitive = ConvPrimitivePool<T>::Get(conv_params);
auto conv_fwd_pd = conv_primitive->GetPrimitiveDesc();
mkldnn::engine& cpu_engine = GetEngine();
// Per ONNX spec,
// X (src) is NCHW, W (filter) is OIHW/GOIHW, and Y (dst) is NCHW
auto src_md = mkldnn::memory::desc(src_dims_mkl, MklDnnType<T>(), mkldnn::memory::format::nchw);
auto filter_format = group_mkl == 1 ? mkldnn::memory::format::oihw : mkldnn::memory::format::goihw;
auto dst_md = mkldnn::memory::desc(dst_dims_mkl, MklDnnType<T>(), mkldnn::memory::format::nchw);
enum mkldnn::memory::format src_format = mkldnn::memory::format::format_undef;
enum mkldnn::memory::format filter_format = mkldnn::memory::format::format_undef;
enum mkldnn::memory::format dst_format = mkldnn::memory::format::format_undef;
if (kernel_rank == 1) {
src_format = mkldnn::memory::format::ncw;
if (group_mkl == 1) {
filter_format = mkldnn::memory::format::oiw;
} else {
filter_format = mkldnn::memory::format::goiw;
}
dst_format = mkldnn::memory::format::ncw;
} else if (kernel_rank == 2) {
src_format = mkldnn::memory::format::nchw;
if (group_mkl == 1) {
filter_format = mkldnn::memory::format::oihw;
} else {
filter_format = mkldnn::memory::format::goihw;
}
dst_format = mkldnn::memory::format::nchw;
} else {
src_format = mkldnn::memory::format::ncdhw;
if (group_mkl == 1) {
filter_format = mkldnn::memory::format::oidhw;
} else {
filter_format = mkldnn::memory::format::goidhw;
}
dst_format = mkldnn::memory::format::ncdhw;
}
auto src_md = mkldnn::memory::desc(src_dims_mkl, MklDnnType<T>(), src_format);
auto dst_md = mkldnn::memory::desc(dst_dims_mkl, MklDnnType<T>(), dst_format);
// Reorder src memory layout if necessary.
if (src_md.data.format != conv2d_primitive->GetSrcMemoryFormat()) {
if (src_md.data.format != conv_primitive->GetSrcMemoryFormat()) {
auto pd = mkldnn::memory::primitive_desc(src_md, cpu_engine);
mkldnn::memory src = mkldnn::memory(pd, (void*)src_data);
// allocate the size queried from memory primitive desc. it may not match tensor logical size due to
// mkldnn using padding to allow use of blocked format.
src_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, conv2d_primitive->GetSrcSize());
src_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, conv_primitive->GetSrcSize());
mkldnn::memory dst = mkldnn::memory(conv_fwd_pd->src_primitive_desc(), src_reorder_buffer.get());
MemoryReorderParams params(src, dst);
DoReorder<T>(params);
@ -378,7 +403,7 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
}
// Reorder filter memory layout if necessary.
if (filter_format != conv2d_primitive->GetFilterMemoryFormat()) {
if (filter_format != conv_primitive->GetFilterMemoryFormat()) {
auto pd = mkldnn::memory::primitive_desc(mkldnn::memory::desc(filter_dims_mkl,
MklDnnType<T>(),
filter_format),
@ -386,7 +411,7 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
mkldnn::memory src = mkldnn::memory(pd, (void*)filter_data);
// allocate the size queried from memory primitive desc. it may not match tensor logical size due to
// mkldnn using padding to allow use of blocked format.
filter_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, conv2d_primitive->GetFilterSize());
filter_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, conv_primitive->GetFilterSize());
mkldnn::memory dst = mkldnn::memory(conv_fwd_pd->weights_primitive_desc(), filter_reorder_buffer.get());
MemoryReorderParams params(src, dst);
DoReorder<T>(params);
@ -394,17 +419,17 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
}
// Allocate dst buffer if reorder is necessary
if (dst_md.data.format != conv2d_primitive->GetDstMemoryFormat()) {
if (dst_md.data.format != conv_primitive->GetDstMemoryFormat()) {
// allocate the size queried from memory primitive desc. it may not match tensor logical size due to
// mkldnn using padding to allow use of blocked format.
dst_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, conv2d_primitive->GetDstSize());
dst_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, conv_primitive->GetDstSize());
dst_data = static_cast<T*>(dst_reorder_buffer.get());
}
conv2d_primitive->Compute(src_data, filter_data, dst_data, bias_data);
conv_primitive->Compute(src_data, filter_data, dst_data, bias_data);
// Reorder dst memory layout if necessary
if (dst_md.data.format != conv2d_primitive->GetDstMemoryFormat()) {
if (dst_md.data.format != conv_primitive->GetDstMemoryFormat()) {
mkldnn::memory src = mkldnn::memory(conv_fwd_pd->dst_primitive_desc(), (void*)dst_data);
auto pd = mkldnn::memory::primitive_desc(dst_md, cpu_engine);
mkldnn::memory dst = mkldnn::memory(pd, Y->template MutableData<T>());

View file

@ -74,7 +74,7 @@ TEST(ConvTest, Conv1D_1) {
auto expected_vals = {-0.052761781960725784f, 0.11481902748346329f, 0.10833403468132019f, -0.11055534332990646f,
-0.012766072526574135f, 0.07113571465015411f, 0.061429332941770554f};
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); // Conv1d not yet optimized for MKLDNN XP
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
}
// Conv3
@ -106,7 +106,7 @@ TEST(ConvTest, Conv1D_2) {
-0.042245108634233475f, -0.08389100432395935f, -0.2509208619594574f, -0.18825212121009827f,
-0.18779152631759644f, -0.11083387583494186f};
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); // Conv1d not yet optimized for MKLDNN XP
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
}
// Conv1
@ -137,7 +137,7 @@ TEST(ConvTest, Conv1D_Bias) {
vector<int64_t> Y_shape = {2, 1, 4};
auto expected_vals = {0.37892162799835205f, 0.4625728130340576f, 0.4934738576412201f, 0.44801419973373413f,
0.37892162799835205f, 0.2499445676803589f, 0.31682088971138f, 0.32773756980895996f};
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); // Conv1d not yet optimized for MKLDNN XP
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);
}
// Conv47
@ -396,7 +396,7 @@ TEST(ConvTest, Conv3D_1) {
0.10670476406812668f, -0.054437506943941116f, -0.014473143965005875f,
-0.13092079758644104f, 0.10221172869205475f, -0.1479327529668808f,
-0.011351631954312325f, -0.10867488384246826f, -0.05184098333120346f};
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); // Conv3d not yet optimized for MKLDNN XP
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
}
// Conv22
@ -437,7 +437,7 @@ TEST(ConvTest, Conv3D_2) {
0.0f, 0.09152615070343018f, 0.08054415881633759f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); // Conv3d not yet optimized for MKLDNN XP
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
}
// Conv23
@ -520,7 +520,7 @@ TEST(ConvTest, Conv3D_Bias) {
-0.47542816400527954f, -0.5078460574150085f, -0.4205915927886963f, -0.5584549903869629f,
-0.39770257472991943f, -0.45317384600639343f, -0.5598302483558655f, -0.2542789578437805f,
-0.5359901785850525f, -0.48090484738349915f, -0.38603779673576355f, -0.4991581439971924f};
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); // Conv3d not yet optimized for MKLDNN XP
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);
}
TEST(ConvTest, Conv2D_group) {