From ed5abc8b94ac4fb34646e08dcf312cd2010bafaf Mon Sep 17 00:00:00 2001 From: jywu-msft <43355415+jywu-msft@users.noreply.github.com> Date: Fri, 7 Dec 2018 14:48:19 -0800 Subject: [PATCH] mkldnn conv 1d, 3d support. (#130) --- onnxruntime/core/providers/mkldnn/nn/conv.cc | 167 ++++++++++-------- .../test/providers/cpu/nn/conv_op_test.cc | 12 +- 2 files changed, 102 insertions(+), 77 deletions(-) diff --git a/onnxruntime/core/providers/mkldnn/nn/conv.cc b/onnxruntime/core/providers/mkldnn/nn/conv.cc index e978e82e9f..ea0656e900 100644 --- a/onnxruntime/core/providers/mkldnn/nn/conv.cc +++ b/onnxruntime/core/providers/mkldnn/nn/conv.cc @@ -13,21 +13,21 @@ namespace onnxruntime { namespace mkl_dnn { namespace { -// Struct which encapsulates parameters for MKLDNN Conv2d primitive. -struct Conv2dParams { - mkldnn::memory::dims& src_dims; - mkldnn::memory::dims& filter_dims; - mkldnn::memory::dims& bias_dims; - mkldnn::memory::dims& dst_dims; - mkldnn::memory::dims& strides; - mkldnn::memory::dims& dilations; - mkldnn::memory::dims& padding_left; - mkldnn::memory::dims& padding_right; +// Struct which encapsulates parameters for MKLDNN Conv primitive. +struct ConvParams { + const mkldnn::memory::dims& src_dims; + const mkldnn::memory::dims& filter_dims; + const mkldnn::memory::dims& bias_dims; + const mkldnn::memory::dims& dst_dims; + const mkldnn::memory::dims& strides; + const mkldnn::memory::dims& dilations; + const mkldnn::memory::dims& padding_left; + const mkldnn::memory::dims& padding_right; - Conv2dParams(mkldnn::memory::dims& src_dims, mkldnn::memory::dims& filter_dims, - mkldnn::memory::dims& bias_dims, mkldnn::memory::dims& dst_dims, - mkldnn::memory::dims& strides, mkldnn::memory::dims& dilations, - mkldnn::memory::dims& padding_left, mkldnn::memory::dims& padding_right) + ConvParams(const mkldnn::memory::dims& src_dims, const mkldnn::memory::dims& filter_dims, + const mkldnn::memory::dims& bias_dims, const mkldnn::memory::dims& dst_dims, + const mkldnn::memory::dims& strides, const mkldnn::memory::dims& dilations, + const mkldnn::memory::dims& padding_left, const mkldnn::memory::dims& padding_right) : src_dims(src_dims), filter_dims(filter_dims), bias_dims(bias_dims), @@ -37,11 +37,11 @@ struct Conv2dParams { padding_left(padding_left), padding_right(padding_right) {} - // Used as the key for Conv2d Primitive Reuse Pool. + // Used as the key for Conv Primitive Reuse Pool. std::string ToString() const { std::string key; key.reserve(128); - key.append("conv2d_"); + key.append("conv_"); AddDimsToKey(key, src_dims); AddDimsToKey(key, filter_dims); AddDimsToKey(key, bias_dims); @@ -55,9 +55,9 @@ struct Conv2dParams { }; template -class Conv2dPrimitive : public PrimitiveBase { +class ConvPrimitive : public PrimitiveBase { public: - explicit Conv2dPrimitive(const Conv2dParams& params) + explicit ConvPrimitive(const ConvParams& params) : cpu_engine_(GetEngine()) { context_.stream.reset(new mkldnn::stream(mkldnn::stream::kind::eager)); if (context_.conv_fwd == nullptr) { @@ -65,7 +65,7 @@ class Conv2dPrimitive : public PrimitiveBase { } } - ~Conv2dPrimitive() = default; + ~ConvPrimitive() = default; void Compute(const T* src_data, const T* filter_data, const T* dst_data, const T* bias_data = nullptr) { @@ -107,7 +107,7 @@ class Conv2dPrimitive : public PrimitiveBase { } private: - struct Conv2dContext { + struct ConvContext { mkldnn::memory::format src_fmt; mkldnn::memory::format filter_fmt; mkldnn::memory::format dst_fmt; @@ -134,7 +134,7 @@ class Conv2dPrimitive : public PrimitiveBase { std::unique_ptr stream; std::vector net; - Conv2dContext() + ConvContext() : src_fmt(mkldnn::memory::format::any), filter_fmt(mkldnn::memory::format::any), dst_fmt(mkldnn::memory::format::any), @@ -154,7 +154,7 @@ class Conv2dPrimitive : public PrimitiveBase { stream(nullptr) {} }; - void Initialize(const Conv2dParams& params) { + void Initialize(const ConvParams& params) { // Set the memory descriptors to format::any to allow MKLDNN to decide what the optimal memory layout should be // for the computation given the input params. context_.src_md.reset(new mkldnn::memory::desc( @@ -222,33 +222,33 @@ class Conv2dPrimitive : public PrimitiveBase { context_.net.push_back(*context_.conv_fwd); } - Conv2dContext context_; + ConvContext context_; mkldnn::engine& cpu_engine_; }; -// Pool which allows for reuse of MKLDNN Conv2d primitives which are expensive to instantiate. +// Pool which allows for reuse of MKLDNN Conv primitives which are expensive to instantiate. // To address thread safety, the primitives are stored in a map on thread local storage. template -class Conv2dPrimitivePool : public PrimitivePool { +class ConvPrimitivePool : public PrimitivePool { public: - static Conv2dPrimitive* Get(const Conv2dParams& params) { - Conv2dPrimitive* primitive = dynamic_cast*>( - Conv2dPrimitivePool::GetInstance().GetPrimitive(params.ToString())); + static ConvPrimitive* Get(const ConvParams& params) { + ConvPrimitive* primitive = dynamic_cast*>( + ConvPrimitivePool::GetInstance().GetPrimitive(params.ToString())); if (primitive == nullptr) { - auto conv2d_primitive = std::make_unique>(params); - primitive = conv2d_primitive.get(); - Conv2dPrimitivePool::GetInstance().SetPrimitive(params.ToString(), std::move(conv2d_primitive)); + auto conv_primitive = std::make_unique>(params); + primitive = conv_primitive.get(); + ConvPrimitivePool::GetInstance().SetPrimitive(params.ToString(), std::move(conv_primitive)); } return primitive; } private: - Conv2dPrimitivePool() = default; - ~Conv2dPrimitivePool() = default; + ConvPrimitivePool() = default; + ~ConvPrimitivePool() = default; - static Conv2dPrimitivePool& GetInstance() { - static Conv2dPrimitivePool pool; + static ConvPrimitivePool& GetInstance() { + static ConvPrimitivePool pool; return pool; } }; @@ -268,20 +268,20 @@ Status Conv::Compute(OpKernelContext* context) const { ONNXRUNTIME_RETURN_IF_ERROR(onnxruntime::ConvBase::ValidateInputShape(X, W)); std::vector kernel_shape = onnxruntime::ConvBase::ComputeKernelShape(W->Shape()); + const size_t kernel_rank = kernel_shape.size(); - // TODO: Support more than 2d kernels - if (kernel_shape.size() != 2) { + if (kernel_rank > 3) { // Fall Back to CPU implementation. return onnxruntime::Conv::Compute(context); } - if (kernel_shape.size() + 2 != W->Shape().NumDimensions()) { + if (kernel_rank + 2 != W->Shape().NumDimensions()) { return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape num_dims is not compatible with W num_dims.", " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(), " W: ", W->Shape().ToString().c_str()); } - for (size_t i = 0; i < kernel_shape.size(); ++i) { + for (size_t i = 0; i < kernel_rank; ++i) { if (kernel_shape[i] != W->Shape()[i + 2]) { return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.", " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(), @@ -291,15 +291,15 @@ Status Conv::Compute(OpKernelContext* context) const { std::vector pads(onnxruntime::ConvBase::pads_); if (pads.empty()) { - pads.resize(kernel_shape.size() * 2, 0); + pads.resize(kernel_rank * 2, 0); } std::vector dilations(onnxruntime::ConvBase::dilations_); if (dilations.empty()) { - dilations.resize(kernel_shape.size(), 1); + dilations.resize(kernel_rank, 1); } std::vector strides(onnxruntime::ConvBase::strides_); if (strides.empty()) { - strides.resize(kernel_shape.size(), 1); + strides.resize(kernel_rank, 1); } std::vector Y_dims; @@ -314,21 +314,19 @@ Status Conv::Compute(OpKernelContext* context) const { if (group_mkl == 1) { filter_dims_mkl.assign(W->Shape().GetDims().begin(), W->Shape().GetDims().end()); } else { - filter_dims_mkl.assign({ - group_mkl, - static_cast(W->Shape()[0] / group_mkl), - static_cast(W->Shape()[1]), - static_cast(W->Shape()[2]), - static_cast(W->Shape()[3]), - }); + filter_dims_mkl.assign({group_mkl, + static_cast(W->Shape()[0] / group_mkl)}); + filter_dims_mkl.insert(filter_dims_mkl.end(), W->Shape().GetDims().begin() + 1, W->Shape().GetDims().end()); } mkldnn::memory::dims strides_mkl(strides.begin(), strides.end()); mkldnn::memory::dims dilations_mkl(dilations.begin(), dilations.end()); // mkldnn dilations start from 0 so we need to subtract 1 from each dim. - dilations_mkl[0] -= 1; - dilations_mkl[1] -= 1; - mkldnn::memory::dims padding_left_mkl(pads.begin(), pads.begin() + 2); - mkldnn::memory::dims padding_right_mkl(pads.begin() + 2, pads.end()); + for (size_t dim = 0; dim < kernel_rank; dim++) { + dilations_mkl[dim] -= 1; + } + + mkldnn::memory::dims padding_left_mkl(pads.begin(), pads.begin() + kernel_rank); + mkldnn::memory::dims padding_right_mkl(pads.begin() + kernel_rank, pads.end()); mkldnn::memory::dims dst_dims_mkl(Y_dims.begin(), Y_dims.end()); mkldnn::memory::dims bias_dims_mkl; if (B != nullptr) { @@ -350,27 +348,54 @@ Status Conv::Compute(OpKernelContext* context) const { } try { - Conv2dParams conv2d_params(src_dims_mkl, filter_dims_mkl, bias_dims_mkl, - dst_dims_mkl, strides_mkl, dilations_mkl, - padding_left_mkl, padding_right_mkl); - Conv2dPrimitive* conv2d_primitive = Conv2dPrimitivePool::Get(conv2d_params); - auto conv_fwd_pd = conv2d_primitive->GetPrimitiveDesc(); + ConvParams conv_params(src_dims_mkl, filter_dims_mkl, bias_dims_mkl, + dst_dims_mkl, strides_mkl, dilations_mkl, + padding_left_mkl, padding_right_mkl); + ConvPrimitive* conv_primitive = ConvPrimitivePool::Get(conv_params); + auto conv_fwd_pd = conv_primitive->GetPrimitiveDesc(); mkldnn::engine& cpu_engine = GetEngine(); - // Per ONNX spec, - // X (src) is NCHW, W (filter) is OIHW/GOIHW, and Y (dst) is NCHW - auto src_md = mkldnn::memory::desc(src_dims_mkl, MklDnnType(), mkldnn::memory::format::nchw); - auto filter_format = group_mkl == 1 ? mkldnn::memory::format::oihw : mkldnn::memory::format::goihw; - auto dst_md = mkldnn::memory::desc(dst_dims_mkl, MklDnnType(), mkldnn::memory::format::nchw); + enum mkldnn::memory::format src_format = mkldnn::memory::format::format_undef; + enum mkldnn::memory::format filter_format = mkldnn::memory::format::format_undef; + enum mkldnn::memory::format dst_format = mkldnn::memory::format::format_undef; + + if (kernel_rank == 1) { + src_format = mkldnn::memory::format::ncw; + if (group_mkl == 1) { + filter_format = mkldnn::memory::format::oiw; + } else { + filter_format = mkldnn::memory::format::goiw; + } + dst_format = mkldnn::memory::format::ncw; + } else if (kernel_rank == 2) { + src_format = mkldnn::memory::format::nchw; + if (group_mkl == 1) { + filter_format = mkldnn::memory::format::oihw; + } else { + filter_format = mkldnn::memory::format::goihw; + } + dst_format = mkldnn::memory::format::nchw; + } else { + src_format = mkldnn::memory::format::ncdhw; + if (group_mkl == 1) { + filter_format = mkldnn::memory::format::oidhw; + } else { + filter_format = mkldnn::memory::format::goidhw; + } + dst_format = mkldnn::memory::format::ncdhw; + } + + auto src_md = mkldnn::memory::desc(src_dims_mkl, MklDnnType(), src_format); + auto dst_md = mkldnn::memory::desc(dst_dims_mkl, MklDnnType(), dst_format); // Reorder src memory layout if necessary. - if (src_md.data.format != conv2d_primitive->GetSrcMemoryFormat()) { + if (src_md.data.format != conv_primitive->GetSrcMemoryFormat()) { auto pd = mkldnn::memory::primitive_desc(src_md, cpu_engine); mkldnn::memory src = mkldnn::memory(pd, (void*)src_data); // allocate the size queried from memory primitive desc. it may not match tensor logical size due to // mkldnn using padding to allow use of blocked format. - src_reorder_buffer = IAllocator::MakeUniquePtr(alloc, conv2d_primitive->GetSrcSize()); + src_reorder_buffer = IAllocator::MakeUniquePtr(alloc, conv_primitive->GetSrcSize()); mkldnn::memory dst = mkldnn::memory(conv_fwd_pd->src_primitive_desc(), src_reorder_buffer.get()); MemoryReorderParams params(src, dst); DoReorder(params); @@ -378,7 +403,7 @@ Status Conv::Compute(OpKernelContext* context) const { } // Reorder filter memory layout if necessary. - if (filter_format != conv2d_primitive->GetFilterMemoryFormat()) { + if (filter_format != conv_primitive->GetFilterMemoryFormat()) { auto pd = mkldnn::memory::primitive_desc(mkldnn::memory::desc(filter_dims_mkl, MklDnnType(), filter_format), @@ -386,7 +411,7 @@ Status Conv::Compute(OpKernelContext* context) const { mkldnn::memory src = mkldnn::memory(pd, (void*)filter_data); // allocate the size queried from memory primitive desc. it may not match tensor logical size due to // mkldnn using padding to allow use of blocked format. - filter_reorder_buffer = IAllocator::MakeUniquePtr(alloc, conv2d_primitive->GetFilterSize()); + filter_reorder_buffer = IAllocator::MakeUniquePtr(alloc, conv_primitive->GetFilterSize()); mkldnn::memory dst = mkldnn::memory(conv_fwd_pd->weights_primitive_desc(), filter_reorder_buffer.get()); MemoryReorderParams params(src, dst); DoReorder(params); @@ -394,17 +419,17 @@ Status Conv::Compute(OpKernelContext* context) const { } // Allocate dst buffer if reorder is necessary - if (dst_md.data.format != conv2d_primitive->GetDstMemoryFormat()) { + if (dst_md.data.format != conv_primitive->GetDstMemoryFormat()) { // allocate the size queried from memory primitive desc. it may not match tensor logical size due to // mkldnn using padding to allow use of blocked format. - dst_reorder_buffer = IAllocator::MakeUniquePtr(alloc, conv2d_primitive->GetDstSize()); + dst_reorder_buffer = IAllocator::MakeUniquePtr(alloc, conv_primitive->GetDstSize()); dst_data = static_cast(dst_reorder_buffer.get()); } - conv2d_primitive->Compute(src_data, filter_data, dst_data, bias_data); + conv_primitive->Compute(src_data, filter_data, dst_data, bias_data); // Reorder dst memory layout if necessary - if (dst_md.data.format != conv2d_primitive->GetDstMemoryFormat()) { + if (dst_md.data.format != conv_primitive->GetDstMemoryFormat()) { mkldnn::memory src = mkldnn::memory(conv_fwd_pd->dst_primitive_desc(), (void*)dst_data); auto pd = mkldnn::memory::primitive_desc(dst_md, cpu_engine); mkldnn::memory dst = mkldnn::memory(pd, Y->template MutableData()); diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index d7edd4dacd..5c4b1659bd 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -74,7 +74,7 @@ TEST(ConvTest, Conv1D_1) { auto expected_vals = {-0.052761781960725784f, 0.11481902748346329f, 0.10833403468132019f, -0.11055534332990646f, -0.012766072526574135f, 0.07113571465015411f, 0.061429332941770554f}; - TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); // Conv1d not yet optimized for MKLDNN XP + TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } // Conv3 @@ -106,7 +106,7 @@ TEST(ConvTest, Conv1D_2) { -0.042245108634233475f, -0.08389100432395935f, -0.2509208619594574f, -0.18825212121009827f, -0.18779152631759644f, -0.11083387583494186f}; - TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); // Conv1d not yet optimized for MKLDNN XP + TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } // Conv1 @@ -137,7 +137,7 @@ TEST(ConvTest, Conv1D_Bias) { vector Y_shape = {2, 1, 4}; auto expected_vals = {0.37892162799835205f, 0.4625728130340576f, 0.4934738576412201f, 0.44801419973373413f, 0.37892162799835205f, 0.2499445676803589f, 0.31682088971138f, 0.32773756980895996f}; - TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); // Conv1d not yet optimized for MKLDNN XP + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); } // Conv47 @@ -396,7 +396,7 @@ TEST(ConvTest, Conv3D_1) { 0.10670476406812668f, -0.054437506943941116f, -0.014473143965005875f, -0.13092079758644104f, 0.10221172869205475f, -0.1479327529668808f, -0.011351631954312325f, -0.10867488384246826f, -0.05184098333120346f}; - TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); // Conv3d not yet optimized for MKLDNN XP + TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } // Conv22 @@ -437,7 +437,7 @@ TEST(ConvTest, Conv3D_2) { 0.0f, 0.09152615070343018f, 0.08054415881633759f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; - TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); // Conv3d not yet optimized for MKLDNN XP + TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } // Conv23 @@ -520,7 +520,7 @@ TEST(ConvTest, Conv3D_Bias) { -0.47542816400527954f, -0.5078460574150085f, -0.4205915927886963f, -0.5584549903869629f, -0.39770257472991943f, -0.45317384600639343f, -0.5598302483558655f, -0.2542789578437805f, -0.5359901785850525f, -0.48090484738349915f, -0.38603779673576355f, -0.4991581439971924f}; - TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); // Conv3d not yet optimized for MKLDNN XP + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); } TEST(ConvTest, Conv2D_group) {