Reduce memory footprint of MKL-DNN EP (#1429)

* MKL-DNN EP memory fix patch

* Call default provider for Opset10

* opset 10 fix

* removed email header from patch

* UseSubgraph method refactored
This commit is contained in:
Sreekanth Yalachigere 2019-07-18 22:57:00 -07:00 committed by jywu-msft
parent 887930e6c2
commit f3c74ec3e9
5 changed files with 139 additions and 13 deletions

View file

@ -61,6 +61,7 @@ if (onnxruntime_USE_MKLDNN)
set(MKLDNN_INCLUDE_DIR ${MKLDNN_INSTALL}/include)
if(NOT onnxruntime_BUILD_FOR_NATIVE_MACHINE)
set(MKLDNN_PATCH_COMMAND1 git apply ${CMAKE_SOURCE_DIR}/patches/mkldnn/platform.cmake.patch)
set(MKLDNN_PATCH_COMMAND2 git apply ${CMAKE_SOURCE_DIR}/patches/mkldnn/mem-patch.cmake.patch)
# discard prior changes due to patching in mkldnn source to unblock incremental builds.
set(MKLDNN_PATCH_DISCARD_COMMAND cd ${MKLDNN_SOURCE} && git checkout -- .)
endif()
@ -69,6 +70,7 @@ if (onnxruntime_USE_MKLDNN)
GIT_REPOSITORY ${MKLDNN_URL}
GIT_TAG ${MKLDNN_TAG}
PATCH_COMMAND ${MKLDNN_PATCH_DISCARD_COMMAND} COMMAND ${MKLDNN_PATCH_COMMAND1}
COMMAND ${MKLDNN_PATCH_COMMAND2}
SOURCE_DIR ${MKLDNN_SOURCE}
CMAKE_ARGS -DMKLDNN_PRODUCT_BUILD_MODE=OFF -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL} -DMKLROOT=${MKML_DIR}
)

View file

@ -0,0 +1,107 @@
---
src/cpu/jit_avx2_1x1_convolution.cpp | 6 +++---
src/cpu/jit_avx512_common_1x1_convolution.cpp | 9 ++++-----
src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp | 6 ++++--
src/cpu/jit_uni_1x1_conv_utils.hpp | 3 ++-
4 files changed, 13 insertions(+), 11 deletions(-)
diff --git a/src/cpu/jit_avx2_1x1_convolution.cpp b/src/cpu/jit_avx2_1x1_convolution.cpp
index 46362886..edb2b6fb 100644
--- a/src/cpu/jit_avx2_1x1_convolution.cpp
+++ b/src/cpu/jit_avx2_1x1_convolution.cpp
@@ -50,7 +50,7 @@ void jit_avx2_1x1_convolution_fwd_t::execute_forward() const {
const memory_desc_wrapper weights_d(pd()->weights_pd(0));
const auto &jcp = kernel_->jcp;
- auto rtus_space = scratchpad().get<data_t>(key_conv_rtus_space);
+ auto rtus_space = pd()->rtus_.reduce_src_?scratchpad().get<data_t>(key_conv_rtus_space):NULL;
const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
const int ndims = dst_d.ndims();
@@ -180,7 +180,7 @@ void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data() const {
const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
const auto &jcp = kernel_->jcp;
- auto rtus_space = scratchpad().get<data_t>(key_conv_rtus_space);
+ auto rtus_space = pd()->rtus_.reduce_src_?scratchpad().get<data_t>(key_conv_rtus_space):NULL;
// TODO (Roma): remove this restriction
assert(jcp.stride_w == 1 && jcp.stride_h == 1);
@@ -306,7 +306,7 @@ void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights() const {
const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
const auto &jcp = kernel_->jcp;
- auto rtus_space = scratchpad.get<data_t>(key_conv_rtus_space);
+ auto rtus_space = pd()->rtus_.reduce_src_?scratchpad.get<data_t>(key_conv_rtus_space):NULL;
data_t *diff_bias = pd()->wants_padded_bias()
? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
diff --git a/src/cpu/jit_avx512_common_1x1_convolution.cpp b/src/cpu/jit_avx512_common_1x1_convolution.cpp
index 6879cd91..6a32aa49 100644
--- a/src/cpu/jit_avx512_common_1x1_convolution.cpp
+++ b/src/cpu/jit_avx512_common_1x1_convolution.cpp
@@ -106,7 +106,7 @@ execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
const memory_desc_wrapper weights_d(pd()->weights_pd(0));
const auto &jcp = kernel_->jcp;
- auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space);
+ auto rtus_space = pd()->rtus_.reduce_src_?scratchpad.get<src_data_t>(key_conv_rtus_space):NULL;
const int ndims = src_d.ndims();
const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
@@ -301,9 +301,8 @@ void jit_avx512_common_1x1_convolution_bwd_data_t<diff_dst_type, wei_type,
const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
const auto &jcp = kernel_->jcp;
- auto rtus_space = scratchpad().template get<diff_src_data_t>(
- key_conv_rtus_space);
-
+ auto rtus_space = pd()->rtus_.reduce_src_? scratchpad().template get<diff_src_data_t>(key_conv_rtus_space): NULL;
+
const int ndims = diff_src_d.ndims();
// TODO (Roma): remove this restriction
@@ -470,7 +469,7 @@ void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights()
const auto scratchpad = this->scratchpad();
- auto rtus_space = scratchpad.get<data_t>(key_conv_rtus_space);
+ auto rtus_space = pd()->rtus_.reduce_src_?scratchpad.get<data_t>(key_conv_rtus_space):NULL;
data_t *diff_bias = pd()->wants_padded_bias()
? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
auto wei_reduction = scratchpad.get<data_t>(key_conv_wei_reduction);
diff --git a/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp b/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp
index de303cd2..ec0c54e7 100644
--- a/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp
+++ b/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp
@@ -100,8 +100,10 @@ void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type, dst_type>
? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
const auto &jcp = kernel_->jcp;
- auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space);
- auto local_scales = scratchpad.get<float>(key_conv_adjusted_scales);
+
+ auto rtus_space = pd()->rtus_.reduce_src_?scratchpad.get<src_data_t>(key_conv_rtus_space):NULL;
+
+ auto local_scales = scratchpad.get<float>(key_conv_adjusted_scales);
const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
diff --git a/src/cpu/jit_uni_1x1_conv_utils.hpp b/src/cpu/jit_uni_1x1_conv_utils.hpp
index a3ed769a..5a0e0635 100644
--- a/src/cpu/jit_uni_1x1_conv_utils.hpp
+++ b/src/cpu/jit_uni_1x1_conv_utils.hpp
@@ -94,7 +94,8 @@ inline void rtus_prepare(conv_pd_t *self, const convolution_desc_t *&conv_d,
template <typename conv_pd_t>
inline void rtus_prepare_space_info(conv_pd_t *self,
memory_tracking::registrar_t &scratchpad) {
- const auto &jcp = self->jcp_;
+ if (!self->rtus_.reduce_src_) return;
+ const auto &jcp = self->jcp_;
const int max_threads = mkldnn_get_max_threads();
const size_t factor = utils::pick_by_prop_kind(self->desc()->prop_kind,
--
2.17.0.windows.1

View file

@ -112,13 +112,14 @@ std::shared_ptr<KernelRegistry> MKLDNNExecutionProvider::GetKernelRegistry() con
}
bool MKLDNNExecutionProvider::UseSubgraph(const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const KernelRegistry*>& kernel_registries,
std::vector<std::unique_ptr<ComputeCapability>>& result) const {
const std::vector<const KernelRegistry*>& kernel_registries) const {
// switch between mkldnn-vanilla and mkldnn-subgraph implementation using
// ORT_MKLDNN_SUBGRAPH environment variable
bool use_subgraph = true;
bool FP16_graph = false;
bool mkldnn_nodes_in_the_graph = false;
if (graph_viewer.MaxNodeIndex() > 0) {
int index = 0;
auto node = graph_viewer.GetNode(index);
@ -130,16 +131,27 @@ bool MKLDNNExecutionProvider::UseSubgraph(const onnxruntime::GraphViewer& graph_
FP16_graph = node->InputDefs()[0]->Type()->find("16") != std::string::npos;
}
if (FP16_graph) {
for (auto node_index = 0; node_index < graph_viewer.MaxNodeIndex(); node_index++) {
auto node = graph_viewer.GetNode(node_index);
if (node == nullptr) {
node_index++;
continue;
}
auto op_it = mkldnn_ops_.find(node->OpType());
if (op_it != mkldnn_ops_.end()) {
mkldnn_nodes_in_the_graph = true;
break;
}
}
if (FP16_graph || !mkldnn_nodes_in_the_graph) {
// FP16 is not supported yet; the subgraph path is also skipped when the graph has no MKL-DNN ops.
use_subgraph = false;
result = IExecutionProvider::GetCapability(graph_viewer, kernel_registries);
} else {
const char* env = getenv("ORT_MKLDNN_SUBGRAPH");
if (env != nullptr) {
if (atoi(env) == 0) {
use_subgraph = false;
result = IExecutionProvider::GetCapability(graph_viewer, kernel_registries);
}
}
}
@ -209,16 +221,16 @@ std::vector<std::unique_ptr<ComputeCapability>> MKLDNNExecutionProvider::GetCapa
const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const KernelRegistry*>& kernel_registries) const {
ORT_UNUSED_PARAMETER(kernel_registries);
std::vector<std::unique_ptr<ComputeCapability>> result;
// temporary switch to toggle between mkldnn-vanilla and mkldnn-subgraph implementation using
// ORT_MKLDNN_SUBGRAPH environment variable
if (UseSubgraph(graph_viewer, kernel_registries, result) == false) {
return result;
if (UseSubgraph(graph_viewer, kernel_registries) == false) {
return IExecutionProvider::GetCapability(graph_viewer, kernel_registries);
}
LOGS_DEFAULT(INFO) << "Using MKL-DNN Subgraph";
// use sub-graph implementation
std::vector<std::unique_ptr<ComputeCapability>> result;
mkl_dnn::Subgraph::SubgraphVariables sub_var;
std::shared_ptr<mkl_dnn::Subgraph> subgraph_ptr;
@ -243,6 +255,12 @@ std::vector<std::unique_ptr<ComputeCapability>> MKLDNNExecutionProvider::GetCapa
if (IsDimensionSupported(node) == false) {
node_index++;
if (subgraph_ptr->mkldnn_nodes.size() > 0) {
CreateMetaDef(graph_viewer, subgraph_attributes, subgraph_ptr, sub_var, result);
subgraph_ptr.reset(new mkl_dnn::Subgraph(graph_name));
subgraph_attributes.clear();
output_to_source_node_map.clear();
}
continue;
}
@ -436,7 +454,7 @@ Status MKLDNNExecutionProvider::Compile(const std::vector<onnxruntime::Node*>& f
compute_info.compute_func = [](FunctionState state, const OrtCustomOpApi* api, OrtKernelContext* context) {
onnxruntime::mkl_dnn::MkldnnFuncKernel<float>* custom_op = reinterpret_cast<mkl_dnn::MkldnnFuncKernel<float>*>(state);
return custom_op->Compute(api, context);
return custom_op->Compute(api, context);
};
node_compute_funcs.push_back(compute_info);

View file

@ -99,8 +99,7 @@ class MKLDNNExecutionProvider : public IExecutionProvider {
}
bool UseSubgraph(const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const KernelRegistry*>& kernel_registries,
std::vector<std::unique_ptr<ComputeCapability>>& result) const;
const std::vector<const KernelRegistry*>& kernel_registries) const;
// Some dimensions are not supported by MKL-DNN
// example: Pool with NumDimensions <= 3 is not supported

View file

@ -238,13 +238,13 @@ class MklDnnConv : public MklDnnKernel {
if (!bias_dims_mkl.empty()) {
fwd_desc_.reset(new mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward, mkldnn::convolution_direct, *src_md_,
mkldnn::prop_kind::forward_inference, mkldnn::convolution_direct, *src_md_,
*filter_md_, *bias_md_, *primitive_dst_md_,
strides_mkl, dilations_mkl, padding_left_mkl,
padding_right_mkl, mkldnn::padding_kind::zero));
} else {
fwd_desc_.reset(new mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward, mkldnn::convolution_direct, *src_md_,
mkldnn::prop_kind::forward_inference, mkldnn::convolution_direct, *src_md_,
*filter_md_, *primitive_dst_md_, strides_mkl,
dilations_mkl, padding_left_mkl,
padding_right_mkl, mkldnn::padding_kind::zero));