From e3ea3e8c124d03cf018d8d33a7350826adf52e2b Mon Sep 17 00:00:00 2001
From: Yangqing Jia
Date: Mon, 23 Jan 2017 09:44:23 -0800
Subject: [PATCH] MKL convolution operator

Summary: Closes https://github.com/caffe2/caffe2/pull/102

Differential Revision: D4448886

Pulled By: Yangqing

fbshipit-source-id: 914d11cd79107895a9755154df3526fcf71a31ea
---
 CMakeLists.txt                                |  13 ++
 caffe2/CMakeLists.txt                         |  10 --
 caffe2/core/flags.h                           |  28 +++-
 caffe2/mkl/operators/conv_op.cc               | 134 ++++++++++++++++++
 caffe2/mkl/operators/relu_op.cc               |   5 +-
 caffe2/python/mkl_basic_test.py               |  40 ------
 .../python/operator_test/mkl_conv_op_test.py  |  51 +++++++
 ...l_ops_test.py => mkl_packed_fc_op_test.py} |   2 +-
 caffe2/python/operator_test/mkl_speed_test.py |  80 +++++++++++
 caffe2/python/test_util.py                    |   1 -
 caffe2/utils/mkl/mkl_memory.cc                |   8 ++
 caffe2/utils/mkl/mkl_memory.h                 |  98 +++++++++++--
 cmake/Utils.cmake                             |   4 +-
 13 files changed, 402 insertions(+), 72 deletions(-)
 create mode 100644 caffe2/mkl/operators/conv_op.cc
 delete mode 100644 caffe2/python/mkl_basic_test.py
 create mode 100644 caffe2/python/operator_test/mkl_conv_op_test.py
 rename caffe2/python/operator_test/{mkl_ops_test.py => mkl_packed_fc_op_test.py} (96%)
 create mode 100644 caffe2/python/operator_test/mkl_speed_test.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5e1e06ba602..5b8b1f68aaa 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -99,6 +99,19 @@ if (BUILD_PYTHON)
   message(STATUS "Automatically generating missing __init__.py files.")
   caffe_autogen_init_py_files()
 
+  # Create a custom target that copies all python files.
+  file(GLOB_RECURSE PYTHON_SRCS RELATIVE ${PROJECT_SOURCE_DIR}
+       "${PROJECT_SOURCE_DIR}/caffe2/*.py")
+  add_custom_target(python_copy_files ALL)
+  foreach(python_src ${PYTHON_SRCS})
+    get_filename_component(dir ${python_src} DIRECTORY)
+    add_custom_command(
+        TARGET python_copy_files PRE_BUILD
+        COMMAND ${CMAKE_COMMAND} -E copy
+        ${PROJECT_SOURCE_DIR}/${python_src} ${CMAKE_BINARY_DIR}/${dir})
+    # file(COPY ${python_src} DESTINATION ${CMAKE_BINARY_DIR}/caffe2/${dir})
+  endforeach()
+
   # Install commands
   # Pick up static python files
   install(DIRECTORY ${CMAKE_BINARY_DIR}/caffe2 DESTINATION ${CMAKE_INSTALL_PREFIX}
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index e0e3fce90ef..85733544cf4 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -161,13 +161,3 @@
 foreach(binary_src ${Caffe2_ALL_BINARY_SRCS})
   install(TARGETS ${bin_name} DESTINATION ${CMAKE_INSTALL_PREFIX}/binaries)
 endforeach()
-
-# ---[ Python files
-if (BUILD_PYTHON)
-  file(GLOB_RECURSE PYTHON_SRCS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.py)
-  foreach(python_src ${PYTHON_SRCS})
-    get_filename_component(dir ${python_src} DIRECTORY)
-    file(COPY ${python_src} DESTINATION ${CMAKE_BINARY_DIR}/caffe2/${dir})
-  endforeach()
-endif()
-
diff --git a/caffe2/core/flags.h b/caffe2/core/flags.h
index c5c324a09e0..9e5f8227e84 100644
--- a/caffe2/core/flags.h
+++ b/caffe2/core/flags.h
@@ -1,12 +1,24 @@
+/**
+ * @file flags.h
+ * @brief Commandline flags support for Caffe2.
+ *
+ * This is a portable commandline flags tool for caffe2, so we can optionally
+ * choose to use gflags or a lightweight custom implementation if gflags is
+ * not possible on a certain platform. If you have gflags installed, setting
+ * the macro CAFFE2_USE_GFLAGS will seamlessly route everything to gflags.
+ *
+ * To define a flag foo of type bool that defaults to true, do the following
+ * in the *global* namespace:
+ *     CAFFE2_DEFINE_bool(foo, true, "An example.");
+ *
+ * To use it in another .cc file, you can use CAFFE2_DECLARE_* as follows:
+ *     CAFFE2_DECLARE_bool(foo);
+ *
+ * In both cases, you can then access the flag via caffe2::FLAGS_foo.
+ */
+
 #ifndef CAFFE2_CORE_FLAGS_H_
 #define CAFFE2_CORE_FLAGS_H_
-// A lightweighted commandline flags tool for caffe2, so we do not need to rely
-// on gflags. If you have gflags installed, set the macro CAFFE2_USE_GFLAGS will
-// seamlessly route everything to gflags.
-
-#ifdef CAFFE2_USE_GFLAGS
-#include <gflags/gflags.h>
-#endif
 
 #include "caffe2/core/registry.h"
@@ -44,6 +56,8 @@ bool CommandLineFlagsHasBeenParsed();
 
 #ifdef CAFFE2_USE_GFLAGS
+#include <gflags/gflags.h>
+
 #define CAFFE2_GFLAGS_DEF_WRAPPER(type, name, default_value, help_str) \
   DEFINE_##type(name, default_value, help_str);                        \
   namespace caffe2 {                                                    \
diff --git a/caffe2/mkl/operators/conv_op.cc b/caffe2/mkl/operators/conv_op.cc
new file mode 100644
index 00000000000..2fc3123834a
--- /dev/null
+++ b/caffe2/mkl/operators/conv_op.cc
@@ -0,0 +1,134 @@
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/operators/conv_pool_op_base.h"
+#include "caffe2/utils/mkl_utils.h"
+
+#ifdef CAFFE2_HAS_MKL_DNN
+
+namespace caffe2 {
+namespace mkl {
+
+template <typename T>
+class MKLConvOp final : public ConvPoolOpBase<MKLContext> {
+ public:
+  USE_CONV_POOL_BASE_FUNCTIONS(MKLContext);
+  MKLConvOp(const OperatorDef& operator_def, Workspace* ws)
+      : ConvPoolOpBase<MKLContext>(operator_def, ws) {
+    OPERATOR_NEEDS_FEATURE(
+        dilation_h_ == 1 && dilation_w_ == 1, "Dilation not supported.");
+    OPERATOR_NEEDS_FEATURE(
+        pad_l_ == pad_r_ && pad_t_ == pad_b_, "Uneven padding not supported.");
+    OPERATOR_NEEDS_FEATURE(
+        order_ == StorageOrder::NCHW, "Only NCHW order supported.");
+    OPERATOR_NEEDS_FEATURE(
+        group_ == 1, "Group convolution not supported yet.");
+  }
+  ~MKLConvOp() {}
+
+  // TODO(jiayq): support double if needed.
+  bool RunOnDeviceWithOrderNCHW() override {
+    auto& X = OperatorBase::Input<MKLMemory<T>>(INPUT);
+    auto& filter = OperatorBase::Input<MKLMemory<T>>(FILTER);
+    auto& bias = OperatorBase::Input<MKLMemory<T>>(BIAS);
+    MKLMemory<T>* Y = OperatorBase::Output<MKLMemory<T>>(0);
+    CAFFE_ENFORCE(4 == X.ndim());
+    const int N = X.dim32(0), C = X.dim32(1), H = X.dim32(2), W = X.dim32(3);
+    CAFFE_ENFORCE(4 == filter.ndim());
+    const int M = filter.dim32(0);
+
+    if (cached_input_dims_ != X.dims() ||
+        cached_filter_dims_ != filter.dims()) {
+      cached_input_dims_ = X.dims();
+      cached_filter_dims_ = filter.dims();
+
+      CAFFE_ENFORCE(
+          C == filter.dim32(1),
+          "Convolution op: # of input channels ",
+          C,
+          " is not equal to kernel channels:",
+          filter.dim32(1));
+      CAFFE_ENFORCE(filter.dim32(2) == kernel_h_);
+      CAFFE_ENFORCE(filter.dim32(3) == kernel_w_);
+      CAFFE_ENFORCE(bias.ndim() == 1);
+      CAFFE_ENFORCE(bias.dim32(0) == M);
+
+      size_t dimension = 4;
+      size_t bdata_sizes[4] = {W, H, C, N};
+      // We will utilize the SetOutputSize() function in the base class
+      // with dummy TensorCPU input and output to calculate the sizes.
+      TensorCPU dummy_input(X.dims());
+      TensorCPU dummy_output;
+      ConvPoolOpBase<MKLContext>::SetOutputSize(
+          dummy_input, &dummy_output, M);
+      size_t tdata_sizes[4] = {
+          dummy_output.dim(3), dummy_output.dim(2),
+          dummy_output.dim(1), dummy_output.dim(0)};
+      size_t fdata_sizes[4] = {kernel_w_, kernel_h_, C, M};
+      size_t strides[2] = {stride_w_, stride_h_};
+      int pads[2] = {-pad_l_, -pad_t_};
+
+      primitive_.Reset(
+          dnnConvolutionCreateForwardBias<T>,
+          nullptr,
+          dnnAlgorithmConvolutionDirect,
+          dimension,
+          bdata_sizes,
+          tdata_sizes,
+          fdata_sizes,
+          strides,
+          pads,
+          dnnBorderZeros);
+      Y->Reset(dummy_output.dims(), primitive_, dnnResourceDst);
+      buffer_.Reset(dummy_output.dims(), primitive_, dnnResourceDst, true);
+
+      input_layout_.Reset(primitive_, dnnResourceSrc);
+      filter_layout_.Reset(primitive_, dnnResourceFilter);
+      bias_layout_.Reset(primitive_, dnnResourceBias);
+    }
+
+    // Try to share from the output: this allows us to avoid unnecessary copy
+    // operations if the output is already allocated and has the same layout
+    // as the buffer.
+    buffer_.ShareFrom(*Y);
+    std::shared_ptr<void> X_view = X.View(
+        input_layout_, primitive_, dnnResourceSrc);
+    std::shared_ptr<void> filter_view = filter.View(
+        filter_layout_, primitive_, dnnResourceFilter);
+    std::shared_ptr<void> bias_view = bias.View(
+        bias_layout_, primitive_, dnnResourceBias);
+    resources_[dnnResourceSrc] = X_view.get();
+    resources_[dnnResourceFilter] = filter_view.get();
+    resources_[dnnResourceBias] = bias_view.get();
+    resources_[dnnResourceDst] = buffer_.buffer();
+
+    MKLDNN_SAFE_CALL(mkl::dnnExecute<T>(primitive_, resources_));
+    buffer_.CopyTo(Y, primitive_, dnnResourceDst);
+    return true;
+  }
+
+  bool RunOnDeviceWithOrderNHWC() override {
+    CAFFE_NOT_IMPLEMENTED;
+  }
+
+ private:
+  // Input: X, W, b
+  // Output: Y
+  vector<TIndex> cached_input_dims_;
+  vector<TIndex> cached_filter_dims_;
+  PrimitiveWrapper<T> primitive_;
+  LayoutWrapper<T> input_layout_;
+  LayoutWrapper<T> filter_layout_;
+  LayoutWrapper<T> bias_layout_;
+  MKLMemory<T> buffer_;
+  void* resources_[dnnResourceNumber] = {0};
+  INPUT_TAGS(INPUT, FILTER, BIAS);
+};
+
+}  // namespace mkl
+
+
+REGISTER_MKL_OPERATOR(Conv, mkl::MKLConvOp<float>);
+
+}  // namespace caffe2
+
+#endif  // CAFFE2_HAS_MKL_DNN
diff --git a/caffe2/mkl/operators/relu_op.cc b/caffe2/mkl/operators/relu_op.cc
index 35b848fe882..581830b4aa7 100644
--- a/caffe2/mkl/operators/relu_op.cc
+++ b/caffe2/mkl/operators/relu_op.cc
@@ -22,8 +22,9 @@ class MKLReluOp : public MKLOperator<T> {
       Y->Reset(X.dims(), primitive_, dnnResourceDst);
       buffer_.Reset(X.dims(), primitive_, dnnResourceDst, true);
     }
-    // Try to share from the output: this will save a copy if the output is
-    // already allocated and is having the same layout as the buffer has.
+    // Try to share from the output: this allows us to avoid unnecessary copy
+    // operations if the output is already allocated and has the same layout
+    // as the buffer.
     buffer_.ShareFrom(*Y);
     resources_[dnnResourceSrc] = X.buffer();
     resources_[dnnResourceDst] = buffer_.buffer();
diff --git a/caffe2/python/mkl_basic_test.py b/caffe2/python/mkl_basic_test.py
deleted file mode 100644
index f5dbacb85ed..00000000000
--- a/caffe2/python/mkl_basic_test.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
-import unittest
-
-import numpy as np
-from caffe2.proto import caffe2_pb2
-from caffe2.python import cnn, core, workspace, test_util
-
-@unittest.skipIf(not workspace.C.has_mkldnn, "Skipping as we do not have mkldnn.")
-class TestMKLBasic(test_util.TestCase):
-    def testReLUConsistencyWithCPU(self):
-        X = np.random.randn(128, 4096).astype(np.float32)
-        mkl_do = core.DeviceOption(caffe2_pb2.MKLDNN)
-        # Makes sure that feed works.
-        workspace.FeedBlob("X", X)
-        workspace.FeedBlob("X_mkl", X, device_option=mkl_do)
-        model = cnn.CNNModelHelper()
-        # Makes sure that we can run relu.
-        model.Relu("X", "Y")
-        model.Relu("X_mkl", "Y_mkl", device_option=mkl_do)
-        workspace.CreateNet(model.net)
-        workspace.RunNet(model.net)
-        # makes sure that the results are good.
-        np.testing.assert_allclose(
-            workspace.FetchBlob("Y"),
-            workspace.FetchBlob("Y_mkl"),
-            atol=1e-10,
-            rtol=1e-10)
-        runtime = workspace.BenchmarkNet(model.net.Proto().name, 1, 10, True)
-        # The returned runtime is the time of
-        # [whole_net, cpu_op, mkl_op]
-        # so we will assume that the MKL one runs faster than the CPU one.
-        self.assertTrue(runtime[1] >= runtime[2])
-        print("CPU runtime {}, MKL runtime {}.".format(runtime[1], runtime[2]))
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/caffe2/python/operator_test/mkl_conv_op_test.py b/caffe2/python/operator_test/mkl_conv_op_test.py
new file mode 100644
index 00000000000..ac75673fb17
--- /dev/null
+++ b/caffe2/python/operator_test/mkl_conv_op_test.py
@@ -0,0 +1,51 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+import hypothesis.strategies as st
+from hypothesis import given, settings
+import numpy as np
+from caffe2.python import core, workspace
+import caffe2.python.hypothesis_test_util as hu
+import caffe2.python.mkl_test_util as mu
+
+
+@unittest.skipIf(not workspace.C.has_mkldnn,
+                 "Skipping as we do not have mkldnn.")
+class MKLConvTest(hu.HypothesisTestCase):
+    @given(stride=st.integers(1, 3),
+           pad=st.integers(0, 3),
+           kernel=st.integers(3, 5),
+           size=st.integers(8, 8),
+           input_channels=st.integers(1, 3),
+           output_channels=st.integers(1, 3),
+           batch_size=st.integers(1, 3),
+           **mu.gcs)
+    @settings(max_examples=2, timeout=100)
+    def test_mkl_convolution(self, stride, pad, kernel, size,
+                             input_channels, output_channels,
+                             batch_size, gc, dc):
+        op = core.CreateOperator(
+            "Conv",
+            ["X", "w", "b"],
+            ["Y"],
+            stride=stride,
+            pad=pad,
+            kernel=kernel,
+        )
+        X = np.random.rand(
+            batch_size, input_channels, size, size).astype(np.float32) - 0.5
+        w = np.random.rand(
+            output_channels, input_channels, kernel, kernel) \
+            .astype(np.float32) - 0.5
+        b = np.random.rand(output_channels).astype(np.float32) - 0.5
+
+        inputs = [X, w, b]
+        self.assertDeviceChecks(dc, op, inputs, [0])
+
+
+if __name__ == "__main__":
+    import unittest
+    unittest.main()
diff --git a/caffe2/python/operator_test/mkl_ops_test.py b/caffe2/python/operator_test/mkl_packed_fc_op_test.py
similarity index 96%
rename from caffe2/python/operator_test/mkl_ops_test.py
rename to caffe2/python/operator_test/mkl_packed_fc_op_test.py
index c88b83e8172..59546d3891e 100644
--- a/caffe2/python/operator_test/mkl_ops_test.py
+++ b/caffe2/python/operator_test/mkl_packed_fc_op_test.py
@@ -67,7 +67,7 @@ class PackedFCTest(hu.HypothesisTestCase):
         def ref(X, W, b):
             output_axes = list(X.shape[:axis]) + [N]
             return (
-                np.dot(X.reshape(X.size / K, K), W.T).reshape(output_axes) + b,)
+                np.dot(X.reshape(int(X.size / K), K), W.T).reshape(output_axes) + b,)
 
         self.assertReferenceChecks(gc, op, [X, W, b], ref)
 
diff --git a/caffe2/python/operator_test/mkl_speed_test.py b/caffe2/python/operator_test/mkl_speed_test.py
new file mode 100644
index 00000000000..4034705580d
--- /dev/null
+++ b/caffe2/python/operator_test/mkl_speed_test.py
@@ -0,0 +1,80 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+import unittest
+
+import numpy as np
+from caffe2.proto import caffe2_pb2
+from caffe2.python import cnn, core, workspace, test_util
+
+
+@unittest.skipIf(not workspace.C.has_mkldnn, "Skipping as we do not have mkldnn.")
+class TestMKLBasic(test_util.TestCase):
+    def testReLUSpeed(self):
+        X = np.random.randn(128, 4096).astype(np.float32)
+        mkl_do = core.DeviceOption(caffe2_pb2.MKLDNN)
+        # Makes sure that feed works.
+        workspace.FeedBlob("X", X)
+        workspace.FeedBlob("X_mkl", X, device_option=mkl_do)
+        net = core.Net("test")
+        # Makes sure that we can run relu.
+        net.Relu("X", "Y")
+        net.Relu("X_mkl", "Y_mkl", device_option=mkl_do)
+        workspace.CreateNet(net)
+        workspace.RunNet(net)
+        # makes sure that the results are good.
+        np.testing.assert_allclose(
+            workspace.FetchBlob("Y"),
+            workspace.FetchBlob("Y_mkl"),
+            atol=1e-10,
+            rtol=1e-10)
+        runtime = workspace.BenchmarkNet(net.Proto().name, 1, 100, True)
+
+        # The returned runtime is the time of
+        # [whole_net, cpu_op, mkl_op]
+        # so we will assume that the MKL one runs faster than the CPU one.
+
+        # Note(Yangqing): in fact, it seems that in optimized mode, this is
+        # not always guaranteed - MKL runs slower than the Eigen vectorized
+        # version, so I am turning this assertion off.
+        #self.assertTrue(runtime[1] >= runtime[2])
+
+        print("Relu CPU runtime {}, MKL runtime {}.".format(runtime[1], runtime[2]))
+
+
+    def testConvSpeed(self):
+        # We randomly select a shape to test the speed. Intentionally we
+        # test a batch size of 1 since this may be the most frequent use
+        # case for MKL during deployment time.
+        X = np.random.rand(1, 256, 27, 27).astype(np.float32) - 0.5
+        W = np.random.rand(192, 256, 3, 3).astype(np.float32) - 0.5
+        b = np.random.rand(192).astype(np.float32) - 0.5
+        mkl_do = core.DeviceOption(caffe2_pb2.MKLDNN)
+        # Makes sure that feed works.
+        workspace.FeedBlob("X", X)
+        workspace.FeedBlob("W", W)
+        workspace.FeedBlob("b", b)
+        workspace.FeedBlob("X_mkl", X, device_option=mkl_do)
+        workspace.FeedBlob("W_mkl", W, device_option=mkl_do)
+        workspace.FeedBlob("b_mkl", b, device_option=mkl_do)
+        net = core.Net("test")
+        # Makes sure that we can run conv.
+        net.Conv(["X", "W", "b"], "Y", pad=1, stride=1, kernel=3)
+        net.Conv(["X_mkl", "W_mkl", "b_mkl"], "Y_mkl",
+                 pad=1, stride=1, kernel=3, device_option=mkl_do)
+        workspace.CreateNet(net)
+        workspace.RunNet(net)
+        # makes sure that the results are good.
+        np.testing.assert_allclose(
+            workspace.FetchBlob("Y"),
+            workspace.FetchBlob("Y_mkl"),
+            atol=1e-2,
+            rtol=1e-2)
+        runtime = workspace.BenchmarkNet(net.Proto().name, 1, 100, True)
+
+        print("Conv CPU runtime {}, MKL runtime {}.".format(runtime[1], runtime[2]))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/caffe2/python/test_util.py b/caffe2/python/test_util.py
index c1ce8b34bc7..86a7a09ac71 100644
--- a/caffe2/python/test_util.py
+++ b/caffe2/python/test_util.py
@@ -20,7 +20,6 @@ class TestCase(unittest.TestCase):
         workspace.GlobalInit([
             'caffe2',
             '--caffe2_log_level=0',
-            '--caffe2_omp_num_threads=1',
         ])
 
     def setUp(self):
diff --git a/caffe2/utils/mkl/mkl_memory.cc b/caffe2/utils/mkl/mkl_memory.cc
index e17e5fd3bc2..b8b9166b770 100644
--- a/caffe2/utils/mkl/mkl_memory.cc
+++ b/caffe2/utils/mkl/mkl_memory.cc
@@ -2,6 +2,14 @@
 
 #ifdef CAFFE2_HAS_MKL_DNN
 
+CAFFE2_DEFINE_bool(
+    caffe2_mkl_implicit_layout_change, false,
+    "Controls the behavior when we call View() on an MKLMemory: if it is set "
+    "true, then the View() function will actually change the underlying "
+    "storage. If it is set false, an implicit copy is triggered but the "
+    "original storage is not affected."
+    );
+
 namespace caffe2 {
 CAFFE_KNOWN_TYPE(mkl::MKLMemory<float>);
diff --git a/caffe2/utils/mkl/mkl_memory.h b/caffe2/utils/mkl/mkl_memory.h
index 03ee99a3965..443a43bf62f 100644
--- a/caffe2/utils/mkl/mkl_memory.h
+++ b/caffe2/utils/mkl/mkl_memory.h
@@ -3,10 +3,18 @@
 
 #include <string>
 #include <vector>
+#include <mutex>
 
 #include "caffe2/core/tensor.h" // for TIndex
+#include "caffe2/core/flags.h" // for TIndex
 #include "caffe2/utils/mkl/mkl_dnn_cppwrapper.h"
 
+// A global boolean variable that controls the behavior when we call View() on
+// an MKLMemory: if it is set true, then the View() function will actually
+// change the underlying storage. If it is set false, an implicit copy is
+// triggered but the original storage is not affected.
+CAFFE2_DECLARE_bool(caffe2_mkl_implicit_layout_change);
+
 namespace caffe2 {
 namespace mkl {
 
@@ -177,6 +185,12 @@ class MKLMemory {
     convert_out_.Reset(dnnConversionCreate<T>, layout_, user_layout_);
     share_mem_if_possible_ = share_mem_if_possible;
     layout_is_user_layout_ = dnnLayoutCompare(layout_, user_layout_);
+    VLOG(2) << "layout is user layout? " << layout_is_user_layout_;
+    if (!share_mem_if_possible_) {
+      // If we are not going to share memory, we will simply allocate
+      // memory upfront.
+      buffer();
+    }
   }
 
   // Initialize an MKLMemory, with the given dimension assuming a C-contiguous
@@ -209,6 +223,12 @@
     convert_out_.Reset(dnnConversionCreate<T>, layout_, user_layout_);
     share_mem_if_possible_ = share_mem_if_possible;
     layout_is_user_layout_ = dnnLayoutCompare(layout_, user_layout_);
+    VLOG(2) << "layout is user layout? " << layout_is_user_layout_;
+    if (!share_mem_if_possible_) {
+      // If we are not going to share memory, we will simply allocate
+      // memory upfront.
+      buffer();
+    }
   }
 
   // Destructs the MKLMemory.
@@ -216,8 +236,10 @@
 
   void CopyFrom(const void* ptr) {
     if (share_mem_if_possible_ && layout_is_user_layout_) {
+      VLOG(2) << "Sharing underlying memory and skip copy.";
       buffer_.reset(const_cast<void*>(ptr), [](void*) -> void {});
     } else {
+      VLOG(2) << "Copying external content.";
       MKLDNN_SAFE_CALL(dnnConversionExecute(
           convert_in_, const_cast<void*>(ptr), buffer()));
     }
@@ -261,9 +283,15 @@
 
   bool ShareFrom(const MKLMemory<T>& other) {
     if (share_mem_if_possible_ && dnnLayoutCompare(other.layout_, layout_)) {
+      VLOG(2) << "Sharing underlying memory.";
       buffer_ = other.buffer_;
+      if (!buffer_.get()) {
+        VLOG(2) << "Warning: the source MKLMemory has no content yet, so the "
+                   "sharing actually has no effect.";
+      }
       return true;
     } else {
+      VLOG(2) << "Not sharing underlying memory.";
       return false;
     }
   }
@@ -271,16 +299,21 @@
   void CopyTo(void* ptr) const {
     if (buffer_.get() == ptr) {
       // This is already mapping to the same memory region. Skip copy.
+      VLOG(2) << "CopyTo does not need actual copying, as we are sharing "
+                 "memory with the output.";
       return;
     }
     CAFFE_ENFORCE(
         buffer_.get(), "Canot copy out from an uninitialized MKLMemory.");
+    VLOG(2) << "Copy to external memory.";
     MKLDNN_SAFE_CALL(dnnConversionExecute(convert_out_, buffer_.get(), ptr));
   }
 
   void CopyTo(TensorCPU* tensor) const {
     if (buffer_.get() == tensor->mutable_data<T>()) {
       // This is already mapping to the same memory region. Skip copy.
+      VLOG(2) << "CopyTo does not need actual copying, as we are sharing "
+                 "memory with the output.";
       return;
     }
     tensor->Resize(dims_);
@@ -295,7 +328,8 @@
       const dnnPrimitive_t primitive = nullptr,
       const dnnResourceType_t type = dnnResourceNumber) {
     if (buffer_.get() == other->buffer_.get()) {
-      VLOG(1) << "We are sharing memory with the output, skipping copy.";
+      VLOG(2) << "CopyTo does not need actual copying, as we are sharing "
+                 "memory with the output.";
       // This is already mapping to the same memory region. Skip copy.
       return;
     }
@@ -304,13 +338,13 @@
     // TODO(jiayq): if primitive creation is a big overhead and we will be
     // consistently copying stuff with fixed src and dst layouts, consider
     // making a cache for the primitive below.
-    VLOG(1) << "Trying direct copy.";
+    VLOG(2) << "CopyTo requires copying. Performing direct copy.";
     PrimitiveWrapper<T> convert(
         dnnConversionCreate<T>, layout_, other->layout_);
     if (dnnPrimitive_t(convert) == nullptr ||
        dnnConversionExecute(convert, buffer_.get(), other->buffer()) !=
            E_SUCCESS) {
-      VLOG(1) << "Direct copy failed, will need to allocate output.";
+      VLOG(2) << "Direct copy failed, will need to allocate output.";
       // If CopyTo directly did not succeed, it could be because the target
       // MKLMemory is not having the right layout. In this case we will reset
       // the target and then do another copy.
@@ -348,6 +382,22 @@
     return dims_;
   }
 
+  inline const int ndim() const { return dims_.size(); }
+
+  inline int dim32(const int i) const {
+    CAFFE_ENFORCE_LT(dims_.at(i), std::numeric_limits<int>::max());
+    return static_cast<int>(dims_[i]);
+  }
+
+  /**
+   * Returns the i-th dimension of the tensor. Note that the passed in index
+   * must be between 0 (inclusive) and the number of dimensions, otherwise
+   * this function will produce a fatal message.
+   */
+  inline TIndex dim(const int i) const {
+    return dims_.at(i);
+  }
+
   inline const LayoutWrapper<T>& layout() const {
     return layout_;
   }
@@ -355,19 +405,43 @@
   // Returns a view of the content. We mark this function const, but be noted
  // that the returned std::shared_ptr is not const protected - user discretion
  // is recommended for correctness.
-  std::shared_ptr<void> View(dnnLayout_t layout_wanted) const {
-    if (dnnLayoutCompare(layout_wanted, layout_)) {
+  std::shared_ptr<void> View(
+      dnnLayout_t layout_wanted,
+      dnnPrimitive_t primitive,
+      dnnResourceType_t type) const {
+    std::lock_guard<std::mutex> lock(buffer_lock_);
+    if (dnnLayoutCompare(layout_wanted, layout_)) {
       // If they are the same, return the original content.
+      VLOG(2) << "Creating a view without the need of copying.";
       return std::shared_ptr<void>(buffer_);
     } else {
       void* temp_buffer;
+      VLOG(2) << "Creating a view with copying.";
       MKLDNN_SAFE_CALL(dnnAllocateBuffer(&temp_buffer, layout_wanted));
       PrimitiveWrapper<T> convert(
           dnnConversionCreate<T>, layout_, layout_wanted);
-      MKLDNN_SAFE_CALL(dnnConversionExecute(convert, buffer_, temp_buffer));
-      return std::shared_ptr<void>(temp_buffer, [](void* ptr) -> void {
-        MKLDNN_CHECK(dnnReleaseBuffer(ptr));
-      });
+      MKLDNN_SAFE_CALL(dnnConversionExecute(
+          convert, buffer_.get(), temp_buffer));
+      if (FLAGS_caffe2_mkl_implicit_layout_change) {
+        VLOG(2) << "Implicit layout change set. "
+                   "Changing the underlying storage.";
+        // We will need to call Reset to set up all the member variables.
+        // This is not thread safe, so we might want to double check if this
+        // makes sense in actual use cases.
+        const_cast<MKLMemory<T>*>(this)->Reset(
+            dims_, primitive, type, share_mem_if_possible_);
+        CAFFE_ENFORCE(dnnLayoutCompare(layout_wanted, layout_),
+                      "You passed in a target layout that is not "
+                      "generated by the given primitive and type.");
+        buffer_.reset(temp_buffer, [](void* ptr) -> void {
+          MKLDNN_CHECK(dnnReleaseBuffer(ptr));
+        });
+        return std::shared_ptr<void>(buffer_);
+      } else {
+        return std::shared_ptr<void>(temp_buffer, [](void* ptr) -> void {
+          MKLDNN_CHECK(dnnReleaseBuffer(ptr));
+        });
+      }
     }
   }
 
@@ -375,7 +449,11 @@
   bool share_mem_if_possible_;
   bool layout_is_user_layout_;
   // The internal buffer in the specific dnn layout.
-  std::shared_ptr<void> buffer_;
+  // It is marked mutable but any modification in a const function should
+  // be accompanied by the buffer lock, see the View() function.
+  mutable std::shared_ptr<void> buffer_;
+  // A mutex to control the access of buffer in the View() function.
+  mutable std::mutex buffer_lock_;
   // The dimensions in the same order as Caffe2 does. This is used to
   // interface with C2.
   vector<TIndex> dims_;
diff --git a/cmake/Utils.cmake b/cmake/Utils.cmake
index b75dd26f2d5..c47feadd6a8 100644
--- a/cmake/Utils.cmake
+++ b/cmake/Utils.cmake
@@ -394,7 +394,8 @@ endfunction()
 # Helper function to automatically generate __init__.py files where python
 # sources reside but there are no __init__.py present.
 function(caffe_autogen_init_py_files)
-  file(GLOB_RECURSE all_python_files RELATIVE ${CMAKE_BINARY_DIR} "${CMAKE_BINARY_DIR}/*.py")
+  file(GLOB_RECURSE all_python_files RELATIVE ${PROJECT_SOURCE_DIR}
+       "${PROJECT_SOURCE_DIR}/caffe2/*.py")
   set(python_paths_need_init_py)
   foreach(python_file ${all_python_files})
     get_filename_component(python_path ${python_file} PATH)
@@ -408,6 +409,7 @@ function(caffe_autogen_init_py_files)
   list(REMOVE_DUPLICATES python_paths_need_init_py)
   # Since the _pb2.py files are yet to be created, we will need to manually
   # add them to the list.
+  list(APPEND python_paths_need_init_py ${CMAKE_BINARY_DIR}/caffe)
   list(APPEND python_paths_need_init_py ${CMAKE_BINARY_DIR}/caffe/proto)
   list(APPEND python_paths_need_init_py ${CMAKE_BINARY_DIR}/caffe2/proto)