mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Update mkldnn to 0.17.1 and address assumptions related to tensor padding that come with the new mkldnn version. (#79)
This commit is contained in:
parent
900e69ceae
commit
47a6992e1b
6 changed files with 78 additions and 162 deletions
11
cmake/external/mkldnn.cmake
vendored
11
cmake/external/mkldnn.cmake
vendored
|
|
@ -1,17 +1,14 @@
|
|||
include (ExternalProject)
|
||||
|
||||
set(MKLDNN_URL https://github.com/intel/mkl-dnn.git)
|
||||
# If MKLDNN_TAG is updated, check if platform.cmake.patch or mkldnn_sgemm.patch needs to be updated.
|
||||
set(MKLDNN_TAG v0.15)
|
||||
# If MKLDNN_TAG is updated, check if platform.cmake.patch needs to be updated.
|
||||
set(MKLDNN_TAG v0.17.1)
|
||||
set(MKLDNN_SOURCE ${CMAKE_CURRENT_BINARY_DIR}/mkl-dnn/src/mkl-dnn/src)
|
||||
set(MKLDNN_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/mkl-dnn/install)
|
||||
set(MKLDNN_LIB_DIR ${MKLDNN_INSTALL}/lib)
|
||||
set(MKLDNN_INCLUDE_DIR ${MKLDNN_INSTALL}/include)
|
||||
|
||||
# patch for mkldnn_sgemm thread safety bug.
|
||||
# it can be removed once a fix is available in a validated mkldnn release version.
|
||||
set(MKLDNN_PATCH_COMMAND1 git apply ${CMAKE_SOURCE_DIR}/patches/mkldnn/mkldnn_sgemm.patch)
|
||||
set(MKLDNN_PATCH_COMMAND2 git apply ${CMAKE_SOURCE_DIR}/patches/mkldnn/platform.cmake.patch)
|
||||
set(MKLDNN_PATCH_COMMAND1 git apply ${CMAKE_SOURCE_DIR}/patches/mkldnn/platform.cmake.patch)
|
||||
# discard prior changes due to patching in mkldnn source to unblock incremental builds.
|
||||
set(MKLDNN_PATCH_DISCARD_COMMAND cd ${MKLDNN_SOURCE} && git checkout -- .)
|
||||
|
||||
|
|
@ -44,7 +41,7 @@ ExternalProject_Add(project_mkldnn
|
|||
PREFIX mkl-dnn
|
||||
GIT_REPOSITORY ${MKLDNN_URL}
|
||||
GIT_TAG ${MKLDNN_TAG}
|
||||
PATCH_COMMAND ${DOWNLOAD_MKLML} COMMAND ${MKLDNN_PATCH_DISCARD_COMMAND} COMMAND ${MKLDNN_PATCH_COMMAND1} COMMAND ${MKLDNN_PATCH_COMMAND2}
|
||||
PATCH_COMMAND ${DOWNLOAD_MKLML} COMMAND ${MKLDNN_PATCH_DISCARD_COMMAND} COMMAND ${MKLDNN_PATCH_COMMAND1}
|
||||
SOURCE_DIR ${MKLDNN_SOURCE}
|
||||
CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,136 +0,0 @@
|
|||
diff --git a/src/cpu/gemm/jit_avx2_gemm_f32.cpp b/src/cpu/gemm/jit_avx2_gemm_f32.cpp
|
||||
index bf03c57..46793e7 100644
|
||||
--- a/src/cpu/gemm/jit_avx2_gemm_f32.cpp
|
||||
+++ b/src/cpu/gemm/jit_avx2_gemm_f32.cpp
|
||||
@@ -2349,13 +2349,18 @@ void jit_avx2_gemm_f32::sgemm(const char *transa, const char *transb,
|
||||
|
||||
nthr_mn = nthr_m * nthr_n;
|
||||
|
||||
- unsigned int volatile *ompstatus = (unsigned int volatile *)ompstatus_;
|
||||
- if (!ompstatus) return;
|
||||
+ unsigned int *ompstatus_ = nullptr;
|
||||
+ unsigned int volatile *ompstatus = nullptr;
|
||||
|
||||
float *c_buffers = NULL;
|
||||
float *ws_buffers = NULL;
|
||||
|
||||
if (nthr_k > 1) {
|
||||
+ ompstatus_ = (unsigned int *)malloc(
|
||||
+ sizeof(unsigned int *) * nthrs_ * CACHE_LINE_SIZE, 64);
|
||||
+ ompstatus = (unsigned int volatile *)ompstatus_;
|
||||
+ assert(ompstatus);
|
||||
+
|
||||
for (int i = 0; i < nthr; i++)
|
||||
ompstatus[i * CACHE_LINE_SIZE] = 0;
|
||||
|
||||
@@ -2486,8 +2491,10 @@ void jit_avx2_gemm_f32::sgemm(const char *transa, const char *transb,
|
||||
}
|
||||
}
|
||||
|
||||
- if (nthr_k > 1)
|
||||
+ if (nthr_k > 1) {
|
||||
free(c_buffers);
|
||||
+ free(ompstatus_);
|
||||
+ }
|
||||
free(ws_buffers);
|
||||
}
|
||||
|
||||
@@ -2513,9 +2520,6 @@ jit_avx2_gemm_f32::jit_avx2_gemm_f32(
|
||||
ker_b0_ = ker_bn_;
|
||||
}
|
||||
nthrs_ = omp_get_max_threads();
|
||||
- ompstatus_ = (unsigned int *)malloc(
|
||||
- sizeof(unsigned int *) * nthrs_ * CACHE_LINE_SIZE, 64);
|
||||
- assert(ompstatus_);
|
||||
}
|
||||
|
||||
jit_avx2_gemm_f32::~jit_avx2_gemm_f32()
|
||||
@@ -2525,7 +2529,6 @@ jit_avx2_gemm_f32::~jit_avx2_gemm_f32()
|
||||
delete ker_b1_;
|
||||
if (beta_ != 0.0 || (beta_ == 0.0 && hasBias_))
|
||||
delete ker_b0_;
|
||||
- free(ompstatus_);
|
||||
}
|
||||
|
||||
}
|
||||
diff --git a/src/cpu/gemm/jit_avx2_gemm_f32.hpp b/src/cpu/gemm/jit_avx2_gemm_f32.hpp
|
||||
index 7adb2a2..ebbbde0 100644
|
||||
--- a/src/cpu/gemm/jit_avx2_gemm_f32.hpp
|
||||
+++ b/src/cpu/gemm/jit_avx2_gemm_f32.hpp
|
||||
@@ -49,7 +49,6 @@ private:
|
||||
bool hasBias_;
|
||||
struct xbyak_gemm;
|
||||
xbyak_gemm *ker_bn_, *ker_b1_, *ker_b0_;
|
||||
- unsigned int *ompstatus_;
|
||||
int nthrs_;
|
||||
};
|
||||
}
|
||||
diff --git a/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp b/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp
|
||||
index 7959195..fca14f4 100644
|
||||
--- a/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp
|
||||
+++ b/src/cpu/gemm/jit_avx512_common_gemm_f32.cpp
|
||||
@@ -1866,14 +1866,18 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb,
|
||||
nthr = nthr_m * nthr_n * nthr_k;
|
||||
|
||||
nthr_mn = nthr_m * nthr_n;
|
||||
-
|
||||
- unsigned int volatile *ompstatus = (unsigned int volatile *)ompstatus_;
|
||||
- if (!ompstatus) return;
|
||||
+
|
||||
+ unsigned int *ompstatus_ = nullptr;
|
||||
+ unsigned int volatile *ompstatus = nullptr;
|
||||
|
||||
float *c_buffers = NULL;
|
||||
float *ws_buffers = NULL;
|
||||
|
||||
if (nthr_k > 1) {
|
||||
+ ompstatus_ = (unsigned int *)malloc(
|
||||
+ sizeof(unsigned int *) * nthrs_ * CACHE_LINE_SIZE, 64);
|
||||
+ ompstatus = (unsigned int volatile *)ompstatus_;
|
||||
+ assert(ompstatus);
|
||||
for (int i = 0; i < nthr; i++)
|
||||
ompstatus[i * CACHE_LINE_SIZE] = 0;
|
||||
|
||||
@@ -2004,8 +2008,10 @@ void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb,
|
||||
}
|
||||
}
|
||||
|
||||
- if (nthr_k > 1)
|
||||
+ if (nthr_k > 1) {
|
||||
free(c_buffers);
|
||||
+ free(ompstatus_);
|
||||
+ }
|
||||
free(ws_buffers);
|
||||
}
|
||||
|
||||
@@ -2032,10 +2038,6 @@ jit_avx512_common_gemm_f32::jit_avx512_common_gemm_f32(
|
||||
}
|
||||
|
||||
nthrs_ = omp_get_max_threads();
|
||||
- ompstatus_ = (unsigned int *)malloc(
|
||||
- sizeof(unsigned int *) * nthrs_ * CACHE_LINE_SIZE, 64);
|
||||
- assert(ompstatus_);
|
||||
-
|
||||
}
|
||||
|
||||
jit_avx512_common_gemm_f32::~jit_avx512_common_gemm_f32()
|
||||
@@ -2045,7 +2047,6 @@ jit_avx512_common_gemm_f32::~jit_avx512_common_gemm_f32()
|
||||
delete ker_b1_;
|
||||
if (beta_ != 0.0 || (beta_ == 0.0 && hasBias_))
|
||||
delete ker_b0_;
|
||||
- free(ompstatus_);
|
||||
}
|
||||
}
|
||||
}
|
||||
diff --git a/src/cpu/gemm/jit_avx512_common_gemm_f32.hpp b/src/cpu/gemm/jit_avx512_common_gemm_f32.hpp
|
||||
index ede1cf9..c057335 100644
|
||||
--- a/src/cpu/gemm/jit_avx512_common_gemm_f32.hpp
|
||||
+++ b/src/cpu/gemm/jit_avx512_common_gemm_f32.hpp
|
||||
@@ -49,7 +49,6 @@ private:
|
||||
bool hasBias_;
|
||||
struct xbyak_gemm;
|
||||
xbyak_gemm *ker_bn_, *ker_b1_, *ker_b0_;
|
||||
- unsigned int *ompstatus_;
|
||||
int nthrs_;
|
||||
};
|
||||
}
|
||||
|
|
@ -1,14 +1,14 @@
|
|||
diff --git a/cmake/platform.cmake b/cmake/platform.cmake
|
||||
index fa51aa7..3d24fdc 100644
|
||||
index 3597970a..805ce63e 100644
|
||||
--- a/cmake/platform.cmake
|
||||
+++ b/cmake/platform.cmake
|
||||
@@ -64,9 +64,6 @@ elseif(UNIX OR APPLE OR MINGW)
|
||||
# unconditionnaly.
|
||||
set(CMAKE_CCXX_FLAGS "${CMAKE_CCXX_FLAGS} -Wno-pass-failed")
|
||||
@@ -107,9 +107,6 @@ elseif(UNIX OR MINGW)
|
||||
append(CMAKE_CCXX_SANITIZER_FLAGS "-g -fno-omit-frame-pointer")
|
||||
endif()
|
||||
elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||
- if(NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS 5.0)
|
||||
- set(DEF_ARCH_OPT_FLAGS "-march=native -mtune=native")
|
||||
- endif()
|
||||
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 6.0)
|
||||
# suppress warning on assumptions made regarding overflow (#146)
|
||||
set(CMAKE_CCXX_FLAGS "${CMAKE_CCXX_FLAGS} -Wno-strict-overflow")
|
||||
# suppress warning on assumptions made regarding overflow (#146)
|
||||
append(CMAKE_CCXX_NOWARN_FLAGS "-Wno-strict-overflow")
|
||||
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Intel")
|
||||
|
|
|
|||
|
|
@ -96,6 +96,12 @@ class Conv2dPrimitive : public PrimitiveBase {
|
|||
|
||||
mkldnn::memory::format GetDstMemoryFormat() const { return context_.dst_fmt; }
|
||||
|
||||
size_t GetSrcSize() const { return context_.src_size; }
|
||||
|
||||
size_t GetFilterSize() const { return context_.filter_size; }
|
||||
|
||||
size_t GetDstSize() const { return context_.dst_size; }
|
||||
|
||||
mkldnn::convolution_forward::primitive_desc* GetPrimitiveDesc() const {
|
||||
return context_.conv_fwd_pd.get();
|
||||
}
|
||||
|
|
@ -106,6 +112,10 @@ class Conv2dPrimitive : public PrimitiveBase {
|
|||
mkldnn::memory::format filter_fmt;
|
||||
mkldnn::memory::format dst_fmt;
|
||||
|
||||
size_t src_size;
|
||||
size_t filter_size;
|
||||
size_t dst_size;
|
||||
|
||||
std::unique_ptr<mkldnn::memory> src_mem;
|
||||
std::unique_ptr<mkldnn::memory> filter_mem;
|
||||
std::unique_ptr<mkldnn::memory> bias_mem;
|
||||
|
|
@ -128,6 +138,9 @@ class Conv2dPrimitive : public PrimitiveBase {
|
|||
: src_fmt(mkldnn::memory::format::any),
|
||||
filter_fmt(mkldnn::memory::format::any),
|
||||
dst_fmt(mkldnn::memory::format::any),
|
||||
src_size(0),
|
||||
filter_size(0),
|
||||
dst_size(0),
|
||||
src_mem(nullptr),
|
||||
filter_mem(nullptr),
|
||||
bias_mem(nullptr),
|
||||
|
|
@ -180,6 +193,12 @@ class Conv2dPrimitive : public PrimitiveBase {
|
|||
context_.dst_fmt = static_cast<mkldnn::memory::format>(
|
||||
context_.conv_fwd_pd.get()->dst_primitive_desc().desc().data.format);
|
||||
|
||||
context_.src_size = context_.conv_fwd_pd.get()->src_primitive_desc().get_size();
|
||||
|
||||
context_.filter_size = context_.conv_fwd_pd.get()->weights_primitive_desc().get_size();
|
||||
|
||||
context_.dst_size = context_.conv_fwd_pd.get()->dst_primitive_desc().get_size();
|
||||
|
||||
context_.src_mem.reset(
|
||||
new mkldnn::memory(context_.conv_fwd_pd.get()->src_primitive_desc(), nullptr));
|
||||
context_.filter_mem.reset(
|
||||
|
|
@ -188,10 +207,9 @@ class Conv2dPrimitive : public PrimitiveBase {
|
|||
new mkldnn::memory(context_.conv_fwd_pd.get()->dst_primitive_desc(), nullptr));
|
||||
|
||||
if (!params.bias_dims.empty()) {
|
||||
context_.bias_mem.reset(new mkldnn::memory(
|
||||
{{{params.bias_dims}, MklDnnType<T>(), mkldnn::memory::format::x},
|
||||
cpu_engine_},
|
||||
nullptr));
|
||||
context_.bias_mem.reset(
|
||||
new mkldnn::memory(context_.conv_fwd_pd.get()->bias_primitive_desc(), nullptr));
|
||||
|
||||
context_.conv_fwd.reset(new mkldnn::convolution_forward(
|
||||
*context_.conv_fwd_pd, *context_.src_mem, *context_.filter_mem,
|
||||
*context_.bias_mem, *context_.dst_mem));
|
||||
|
|
@ -350,7 +368,9 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
|
|||
if (src_md.data.format != conv2d_primitive->GetSrcMemoryFormat()) {
|
||||
auto pd = mkldnn::memory::primitive_desc(src_md, cpu_engine);
|
||||
mkldnn::memory src = mkldnn::memory(pd, (void*)src_data);
|
||||
src_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, sizeof(T) * X->Shape().Size());
|
||||
// allocate the size queried from memory primitive desc. it may not match tensor logical size due to
|
||||
// mkldnn using padding to allow use of blocked format.
|
||||
src_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, conv2d_primitive->GetSrcSize());
|
||||
mkldnn::memory dst = mkldnn::memory(conv_fwd_pd->src_primitive_desc(), src_reorder_buffer.get());
|
||||
MemoryReorderParams params(src, dst);
|
||||
DoReorder<T>(params);
|
||||
|
|
@ -364,7 +384,9 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
|
|||
filter_format),
|
||||
cpu_engine);
|
||||
mkldnn::memory src = mkldnn::memory(pd, (void*)filter_data);
|
||||
filter_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, sizeof(T) * W->Shape().Size());
|
||||
// allocate the size queried from memory primitive desc. it may not match tensor logical size due to
|
||||
// mkldnn using padding to allow use of blocked format.
|
||||
filter_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, conv2d_primitive->GetFilterSize());
|
||||
mkldnn::memory dst = mkldnn::memory(conv_fwd_pd->weights_primitive_desc(), filter_reorder_buffer.get());
|
||||
MemoryReorderParams params(src, dst);
|
||||
DoReorder<T>(params);
|
||||
|
|
@ -373,7 +395,9 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
|
|||
|
||||
// Allocate dst buffer if reorder is necessary
|
||||
if (dst_md.data.format != conv2d_primitive->GetDstMemoryFormat()) {
|
||||
dst_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, sizeof(T) * Y->Shape().Size());
|
||||
// allocate the size queried from memory primitive desc. it may not match tensor logical size due to
|
||||
// mkldnn using padding to allow use of blocked format.
|
||||
dst_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, conv2d_primitive->GetDstSize());
|
||||
dst_data = static_cast<T*>(dst_reorder_buffer.get());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -74,6 +74,10 @@ class LRNPrimitive : public PrimitiveBase {
|
|||
|
||||
mkldnn::memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
|
||||
|
||||
size_t GetSrcSize() const { return context_.src_size; }
|
||||
|
||||
size_t GetDstSize() const { return context_.dst_size; }
|
||||
|
||||
mkldnn::lrn_forward::primitive_desc* GetPrimitiveDesc() const {
|
||||
return context_.fwd_primitive_desc.get();
|
||||
}
|
||||
|
|
@ -83,6 +87,9 @@ class LRNPrimitive : public PrimitiveBase {
|
|||
mkldnn::memory::format src_fmt;
|
||||
std::unique_ptr<mkldnn::memory::desc> src_md;
|
||||
|
||||
size_t src_size;
|
||||
size_t dst_size;
|
||||
|
||||
std::unique_ptr<mkldnn::memory> src_mem;
|
||||
std::unique_ptr<mkldnn::memory> dst_mem;
|
||||
|
||||
|
|
@ -96,6 +103,8 @@ class LRNPrimitive : public PrimitiveBase {
|
|||
LRNContext()
|
||||
: src_fmt(mkldnn::memory::format::any),
|
||||
src_md(nullptr),
|
||||
src_size(0),
|
||||
dst_size(0),
|
||||
src_mem(nullptr),
|
||||
dst_mem(nullptr),
|
||||
fwd_desc(nullptr),
|
||||
|
|
@ -118,6 +127,9 @@ class LRNPrimitive : public PrimitiveBase {
|
|||
context_.src_fmt = static_cast<mkldnn::memory::format>(
|
||||
context_.fwd_primitive_desc.get()->src_primitive_desc().desc().data.format);
|
||||
|
||||
context_.src_size = context_.fwd_primitive_desc.get()->src_primitive_desc().get_size();
|
||||
context_.dst_size = context_.fwd_primitive_desc.get()->dst_primitive_desc().get_size();
|
||||
|
||||
context_.src_mem.reset(new mkldnn::memory(context_.fwd_primitive_desc.get()->src_primitive_desc(), nullptr));
|
||||
context_.dst_mem.reset(new mkldnn::memory(context_.fwd_primitive_desc.get()->dst_primitive_desc(), nullptr));
|
||||
context_.lrn_fwd.reset(
|
||||
|
|
@ -192,7 +204,9 @@ Status LRN<T>::Compute(OpKernelContext* context) const {
|
|||
if (src_md.data.format != lrn_primitive->GetSrcMemoryFormat()) {
|
||||
auto pd = mkldnn::memory::primitive_desc(src_md, cpu_engine);
|
||||
mkldnn::memory src = mkldnn::memory(pd, (void*)src_data);
|
||||
src_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, sizeof(T) * X->Shape().Size());
|
||||
// allocate the size queried from memory primitive desc. it may not match tensor logical size due to
|
||||
// mkldnn using padding to allow use of blocked format.
|
||||
src_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, lrn_primitive->GetSrcSize());
|
||||
mkldnn::memory dst = mkldnn::memory(fwd_primitive_desc->src_primitive_desc(), src_reorder_buffer.get());
|
||||
MemoryReorderParams params(src, dst);
|
||||
DoReorder<T>(params);
|
||||
|
|
@ -201,7 +215,9 @@ Status LRN<T>::Compute(OpKernelContext* context) const {
|
|||
|
||||
// Allocate dst buffer if reorder is necessary
|
||||
if (src_md.data.format != lrn_primitive->GetSrcMemoryFormat()) {
|
||||
dst_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, sizeof(T) * Y->Shape().Size());
|
||||
// allocate the size queried from memory primitive desc. it may not match tensor logical size due to
|
||||
// mkldnn using padding to allow use of blocked format.
|
||||
dst_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, lrn_primitive->GetDstSize());
|
||||
dst_data = static_cast<T*>(dst_reorder_buffer.get());
|
||||
}
|
||||
|
||||
|
|
@ -222,5 +238,5 @@ Status LRN<T>::Compute(OpKernelContext* context) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace mkl_dnn
|
||||
} // namespace onnxruntime
|
||||
}
|
||||
|
|
|
|||
|
|
@ -103,6 +103,9 @@ class PoolPrimitive : public PrimitiveBase {
|
|||
mkldnn::memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
|
||||
mkldnn::memory::format GetDstMemoryFormat() const { return context_.dst_fmt; }
|
||||
|
||||
size_t GetSrcSize() const { return context_.src_size; }
|
||||
size_t GetDstSize() const { return context_.dst_size; }
|
||||
|
||||
// std::unique_ptr<mkldnn::memory::desc> GetDstMemoryDesc() const { return context_.dst_md; }
|
||||
mkldnn::pooling_forward::primitive_desc* GetPrimitiveDesc() const {
|
||||
return context_.fwd_primitive_desc.get();
|
||||
|
|
@ -113,6 +116,9 @@ class PoolPrimitive : public PrimitiveBase {
|
|||
mkldnn::memory::format src_fmt;
|
||||
mkldnn::memory::format dst_fmt;
|
||||
|
||||
size_t src_size;
|
||||
size_t dst_size;
|
||||
|
||||
std::unique_ptr<mkldnn::memory> src_mem;
|
||||
std::unique_ptr<mkldnn::memory> dst_mem;
|
||||
|
||||
|
|
@ -131,6 +137,8 @@ class PoolPrimitive : public PrimitiveBase {
|
|||
PoolContext()
|
||||
: src_fmt(mkldnn::memory::format::any),
|
||||
dst_fmt(mkldnn::memory::format::any),
|
||||
src_size(0),
|
||||
dst_size(0),
|
||||
src_mem(nullptr),
|
||||
dst_mem(nullptr),
|
||||
fwd_desc(nullptr),
|
||||
|
|
@ -178,6 +186,9 @@ class PoolPrimitive : public PrimitiveBase {
|
|||
context_.dst_fmt = static_cast<mkldnn::memory::format>(
|
||||
context_.fwd_primitive_desc.get()->dst_primitive_desc().desc().data.format);
|
||||
|
||||
context_.src_size = context_.fwd_primitive_desc.get()->src_primitive_desc().get_size();
|
||||
context_.dst_size = context_.fwd_primitive_desc.get()->dst_primitive_desc().get_size();
|
||||
|
||||
context_.src_mem.reset(
|
||||
new mkldnn::memory(context_.fwd_primitive_desc.get()->src_primitive_desc(), nullptr));
|
||||
context_.dst_mem.reset(
|
||||
|
|
@ -288,7 +299,9 @@ Status Pool<T, PoolType>::Compute(OpKernelContext* context) const {
|
|||
if (src_md.data.format != pool_primitive->GetSrcMemoryFormat()) {
|
||||
auto pd = mkldnn::memory::primitive_desc(src_md, cpu_engine);
|
||||
mkldnn::memory src = mkldnn::memory(pd, (void*)src_data);
|
||||
src_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, sizeof(T) * X->Shape().Size());
|
||||
// allocate the size queried from memory primitive desc. it may not match tensor logical size due to
|
||||
// mkldnn using padding to allow use of blocked format.
|
||||
src_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, pool_primitive->GetSrcSize());
|
||||
mkldnn::memory dst = mkldnn::memory(fwd_primitive_desc->src_primitive_desc(), src_reorder_buffer.get());
|
||||
MemoryReorderParams params(src, dst);
|
||||
DoReorder<T>(params);
|
||||
|
|
@ -297,7 +310,9 @@ Status Pool<T, PoolType>::Compute(OpKernelContext* context) const {
|
|||
|
||||
// Allocate dst buffer if reorder is necessary
|
||||
if (dst_md.data.format != pool_primitive->GetDstMemoryFormat()) {
|
||||
dst_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, sizeof(T) * Y->Shape().Size());
|
||||
// allocate the size queried from memory primitive desc. it may not match tensor logical size due to
|
||||
// mkldnn using padding to allow use of blocked format.
|
||||
dst_reorder_buffer = IAllocator::MakeUniquePtr<void>(alloc, pool_primitive->GetDstSize());
|
||||
dst_data = static_cast<T*>(dst_reorder_buffer.get());
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue