mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add serialization interface for MKLMemory
Summary: This allows us to serialize things between MKLMemory and a TensorProto. Reviewed By: dzhulgakov Differential Revision: D4218044 fbshipit-source-id: 934181493b482cb259c17ff4b17008eac52fd885
This commit is contained in:
parent
e65eeff665
commit
ab3fea540d
8 changed files with 433 additions and 58 deletions
|
|
@ -69,7 +69,14 @@ std::string TensorDeviceTypeName(const int32_t& d) {
|
|||
case MKLDNN:
|
||||
return "TensorMKLDNN";
|
||||
default:
|
||||
CAFFE_THROW("Unknown device: ", d);
|
||||
CAFFE_THROW(
|
||||
"Unknown device: ",
|
||||
d,
|
||||
". If you have recently updated the caffe2.proto file to add a new "
|
||||
"device type, did you forget to update the TensorDeviceTypeName() "
|
||||
"function to reflect such recent changes?");
|
||||
// The below code won't run but is needed to suppress some compiler
|
||||
// warnings.
|
||||
return "";
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -114,6 +114,35 @@ inline Dst dynamic_cast_if_rtti(Src ptr) {
|
|||
#endif
|
||||
}
|
||||
|
||||
// SkipIndices are used in operator_fallback_gpu.h and operator_fallback_mkl.h
// as utility functions that mark input / output indices to skip when we use a
// CPU operator as the fallback of GPU/MKL operator option.
template <int... values>
class SkipIndices {
 private:
  // Base case: only one candidate value left in the pack.
  template <int V>
  static inline bool ContainsInternal(const int i) {
    return (i == V);
  }
  // Recursive case: i is contained if it equals the first value OR any of
  // the remaining ones. Note: this must be ||, not && — with && the check
  // could only succeed when every value in the pack equals i, so Contains
  // would effectively always be false for packs of distinct values.
  template <int First, int Second, int... Rest>
  static inline bool ContainsInternal(const int i) {
    return (i == First) || ContainsInternal<Second, Rest...>(i);
  }

 public:
  // Returns true if i is one of the template values.
  static inline bool Contains(const int i) {
    return ContainsInternal<values...>(i);
  }
};

// Specialization for the empty pack: an empty index set contains nothing.
template <>
class SkipIndices<> {
 public:
  static inline bool Contains(const int i) {
    return false;
  }
};
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_CORE_COMMON_H_
|
||||
|
|
|
|||
124
caffe2/mkl/mklmemory_serialization.cc
Normal file
124
caffe2/mkl/mklmemory_serialization.cc
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
#include "caffe2/core/blob.h"
|
||||
#include "caffe2/core/blob_serialization.h"
|
||||
#include "caffe2/utils/mkl_utils.h"
|
||||
|
||||
#ifdef CAFFE2_HAS_MKL_DNN
|
||||
|
||||
namespace caffe2 {
|
||||
namespace mkl {
|
||||
/**
|
||||
* @brief MKLMemorySerializer is the serializer for MKLMemory.
|
||||
*
|
||||
* MKLMemorySerializer takes in a blob that contains an MKLMemory, and
|
||||
* serializes it into a TensorProto protocol buffer.
|
||||
*/
|
||||
class MKLMemorySerializer : public BlobSerializerBase {
|
||||
public:
|
||||
MKLMemorySerializer() {}
|
||||
~MKLMemorySerializer() {}
|
||||
|
||||
void Serialize(
|
||||
const Blob& blob,
|
||||
const string& name,
|
||||
SerializationAcceptor acceptor) override {
|
||||
BlobProto blob_proto;
|
||||
blob_proto.set_name(name);
|
||||
blob_proto.set_type(kTensorBlobType);
|
||||
TensorProto* proto = blob_proto.mutable_tensor();
|
||||
auto* device_detail = proto->mutable_device_detail();
|
||||
device_detail->set_device_type(MKLDNN);
|
||||
proto->set_name(name);
|
||||
if (blob.IsType<MKLMemory<float>>()) {
|
||||
const MKLMemory<float>& src = blob.Get<MKLMemory<float>>();
|
||||
CAFFE_ENFORCE(
|
||||
src.buffer(), "Cannot serialize an empty MKLMemory object.");
|
||||
size_t total = 1;
|
||||
for (int i = 0; i < src.dims().size(); ++i) {
|
||||
proto->add_dims(src.dims()[i]);
|
||||
total *= src.dims()[i];
|
||||
}
|
||||
proto->mutable_float_data()->Reserve(total);
|
||||
while (total--) {
|
||||
proto->add_float_data(0);
|
||||
}
|
||||
src.CopyTo(proto->mutable_float_data()->mutable_data());
|
||||
} else if (blob.IsType<MKLMemory<double>>()) {
|
||||
const MKLMemory<double>& src = blob.Get<MKLMemory<double>>();
|
||||
CAFFE_ENFORCE(
|
||||
src.buffer(), "Cannot serialize an empty MKLMemory object.");
|
||||
size_t total = 1;
|
||||
for (int i = 0; i < src.dims().size(); ++i) {
|
||||
proto->add_dims(src.dims()[i]);
|
||||
total *= src.dims()[i];
|
||||
}
|
||||
proto->mutable_double_data()->Reserve(total);
|
||||
while (total--) {
|
||||
proto->add_double_data(0);
|
||||
}
|
||||
src.CopyTo(proto->mutable_double_data()->mutable_data());
|
||||
} else {
|
||||
CAFFE_THROW(
|
||||
"MKLMemory could only be either float or double. "
|
||||
"Encountered unsupported type.");
|
||||
}
|
||||
acceptor(name, blob_proto.SerializeAsString());
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief MKLMemoryDeserializer is the deserializer for TensorProto that has
|
||||
* MKLDNN as its device.
|
||||
*
|
||||
* The device that the deserialized Tensor will live under is determined by the
|
||||
* device_detail field. If you want to specify the device of the deserialized
|
||||
* tensor, change the TensorProto's corresponding fields before calling
|
||||
* Deserialize.
|
||||
*/
|
||||
class MKLMemoryDeserializer : public BlobDeserializerBase {
|
||||
public:
|
||||
bool Deserialize(const BlobProto& blob_proto, Blob* blob) override {
|
||||
const TensorProto& proto = blob_proto.tensor();
|
||||
CAFFE_ENFORCE(
|
||||
proto.data_type() == TensorProto_DataType_FLOAT ||
|
||||
proto.data_type() == TensorProto_DataType_DOUBLE,
|
||||
"MKLMemory only supports either float or double formats.");
|
||||
CAFFE_ENFORCE(
|
||||
!proto.has_segment(), "MKLMemory does not support segment right now.");
|
||||
vector<TIndex> dims;
|
||||
for (const TIndex d : proto.dims()) {
|
||||
dims.push_back(d);
|
||||
}
|
||||
// TODO: right now, every time we do a deserializer we create a new MKL
|
||||
// Memory object. Optionally, we can change that.
|
||||
switch (proto.data_type()) {
|
||||
case TensorProto_DataType_FLOAT: {
|
||||
auto dst = make_unique<MKLMemory<float>>(dims);
|
||||
dst->CopyFrom(proto.float_data().data());
|
||||
blob->Reset(dst.release());
|
||||
break;
|
||||
}
|
||||
case TensorProto_DataType_DOUBLE: {
|
||||
auto dst = make_unique<MKLMemory<double>>(dims);
|
||||
dst->CopyFrom(proto.double_data().data());
|
||||
blob->Reset(dst.release());
|
||||
break;
|
||||
}
|
||||
default:
|
||||
CAFFE_THROW("This should not happen, we guarded things above already.");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mkl
|
||||
|
||||
// Register the MKL serializer for both supported MKLMemory element types,
// and the deserializer for TensorProtos whose device type is MKLDNN.
REGISTER_BLOB_SERIALIZER(
    (TypeMeta::Id<mkl::MKLMemory<float>>()),
    mkl::MKLMemorySerializer);
REGISTER_BLOB_SERIALIZER(
    (TypeMeta::Id<mkl::MKLMemory<double>>()),
    mkl::MKLMemorySerializer);
REGISTER_BLOB_DESERIALIZER(TensorMKLDNN, mkl::MKLMemoryDeserializer);
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_HAS_MKL_DNN
|
||||
54
caffe2/mkl/mklmemory_serialization_test.cc
Normal file
54
caffe2/mkl/mklmemory_serialization_test.cc
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
#include "caffe2/core/blob.h"
|
||||
#include "caffe2/core/blob_serialization.h"
|
||||
#include "caffe2/core/common.h"
|
||||
#include "caffe2/utils/mkl_utils.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#ifdef CAFFE2_HAS_MKL_DNN
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
using mkl::MKLMemory;
|
||||
|
||||
// Round-trips an MKLMemory<float> through serialization and verifies both the
// serialized proto contents and the deserialized memory.
TEST(MKLTest, MKLMemorySerialization) {
  Blob blob;
  vector<int> shape{2, 3, 4};
  float data[2 * 3 * 4];
  for (int i = 0; i < 2 * 3 * 4; ++i) {
    data[i] = i;
  }
  blob.Reset<MKLMemory<float>>(new MKLMemory<float>(shape));
  MKLMemory<float>* mkl_memory = blob.GetMutable<MKLMemory<float>>();
  mkl_memory->CopyFrom(data);
  string serialized = blob.Serialize("test");
  BlobProto proto;
  CHECK(proto.ParseFromString(serialized));
  EXPECT_EQ(proto.name(), "test");
  EXPECT_EQ(proto.type(), "Tensor");
  EXPECT_TRUE(proto.has_tensor());
  const TensorProto& tensor_proto = proto.tensor();
  EXPECT_EQ(
      tensor_proto.data_type(), TypeMetaToDataType(TypeMeta::Make<float>()));
  EXPECT_EQ(tensor_proto.float_data_size(), 2 * 3 * 4);
  for (int i = 0; i < 2 * 3 * 4; ++i) {
    EXPECT_EQ(tensor_proto.float_data(i), static_cast<float>(i));
  }
  Blob new_blob;
  EXPECT_TRUE(new_blob.Deserialize(serialized));
  EXPECT_TRUE(new_blob.IsType<MKLMemory<float>>());
  // Bug fix: inspect the deserialized blob (new_blob), not the original one;
  // otherwise the round trip is never actually verified.
  const auto& new_mkl_memory = new_blob.Get<MKLMemory<float>>();
  EXPECT_EQ(new_mkl_memory.dims().size(), 3);
  EXPECT_EQ(new_mkl_memory.dims()[0], 2);
  EXPECT_EQ(new_mkl_memory.dims()[1], 3);
  EXPECT_EQ(new_mkl_memory.dims()[2], 4);
  float recovered_data[2 * 3 * 4];
  new_mkl_memory.CopyTo(recovered_data);
  for (int i = 0; i < 2 * 3 * 4; ++i) {
    EXPECT_EQ(recovered_data[i], i);
  }
}
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_HAS_MKL_DNN
|
||||
127
caffe2/mkl/operators/operator_fallback_mkl.h
Normal file
127
caffe2/mkl/operators/operator_fallback_mkl.h
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
#ifndef CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_
|
||||
#define CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_
|
||||
|
||||
#include "caffe2/core/common.h"
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/proto/caffe2.pb.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
/**
|
||||
* @brief A templated class to allow one to wrap a CPU operator as an MKL
|
||||
* operator.
|
||||
*
|
||||
* This class can be used when one does not have the MKL implementation ready
|
||||
* yet for an operator. Essentially, what this op does is to automatically
|
||||
* deal with data copy for you. Plausibly, this causes a lot of overhead and
|
||||
* is not optimal, so you should use this operator mostly for quick prototyping
|
||||
* purpose.
|
||||
*
|
||||
* All the input and output of the original operator should be TensorCPU.
|
||||
*
|
||||
* Example usage: if you have a class MyMagicOp that is CPU based, and you use
|
||||
* the registration code
|
||||
* REGISTER_CPU_OPERATOR(MyMagic, MyMagicOp);
|
||||
* to register the CPU side, you can create its corresponding MKL operator
|
||||
* (with performance hits of course) via
|
||||
* REGISTER_MKL_OPERATOR(MyMagic,
|
||||
* MKLFallbackOp<MyMagicOp>);
|
||||
*
|
||||
* Advanced usage: if you want to have some specific outputs never copied, you
|
||||
* can use the SkipOutputCopy template argument to do that. For example, if
|
||||
* MyMagic produces two outputs and the first output is always going to live on
|
||||
* the CPU, you can do
|
||||
* REGISTER_CUDA_OPERATOR(MyMagic,
|
||||
* MKLFallbackOp<MyMagicOp, SkipIndices<0>>);
|
||||
*/
|
||||
template <class CPUOp, typename SkipOutputCopy = SkipIndices<>>
|
||||
class MKLFallbackOp final : public Operator<MKLContext> {
|
||||
public:
|
||||
USE_OPERATOR_FUNCTIONS(MKLContext);
|
||||
MKLFallbackOp(const OperatorDef& def, Workspace* ws)
|
||||
: Operator<MKLContext>(def, ws) {
|
||||
CAFFE_ENFORCE_EQ(def.device_option().device_type(), MKLDNN);
|
||||
OperatorDef base_def_(def);
|
||||
// base_def_ runs on CPU, so we will set its device option to CPU.
|
||||
base_def_.clear_device_option();
|
||||
base_def_.mutable_device_option()->set_device_type(CPU);
|
||||
// Set up the symbols for the local workspace.
|
||||
for (const string& name : def.input()) {
|
||||
local_input_blobs_.push_back(local_ws_.CreateBlob(name));
|
||||
CHECK_NOTNULL(local_input_blobs_.back());
|
||||
}
|
||||
base_op_.reset(new CPUOp(base_def_, &local_ws_));
|
||||
for (const string& name : def.output()) {
|
||||
local_output_blobs_.push_back(local_ws_.GetBlob(name));
|
||||
CHECK_NOTNULL(local_output_blobs_.back());
|
||||
}
|
||||
}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
for (int i = 0; i < InputSize(); ++i) {
|
||||
if (OperatorBase::InputIsType<MKLMemory<float>>(i)) {
|
||||
OperatorBase::Input<MKLMemory<float>>(i).CopyTo(
|
||||
local_input_blobs_[i]->template GetMutable<TensorCPU>());
|
||||
} else if (OperatorBase::InputIsType<MKLMemory<double>>(i)) {
|
||||
OperatorBase::Input<MKLMemory<double>>(i).CopyTo(
|
||||
local_input_blobs_[i]->template GetMutable<TensorCPU>());
|
||||
} else {
|
||||
VLOG(1) << "Input " << i << " is not MKLMemory. Skipping copy.";
|
||||
// Note(jiayq): This removes a const but conceptually
|
||||
// local_input_blobs will only be used as const blob input for the
|
||||
// base op so we are still fine.
|
||||
local_input_blobs_[i]->ShareExternal(
|
||||
const_cast<void*>(OperatorBase::Inputs()[i]->GetRaw()),
|
||||
OperatorBase::Inputs()[i]->meta());
|
||||
}
|
||||
}
|
||||
|
||||
if (!base_op_->Run()) {
|
||||
LOG(ERROR) << "Base op run failed in MKLFallbackOp. Def: "
|
||||
<< ProtoDebugString(def());
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < OutputSize(); ++i) {
|
||||
if (SkipOutputCopy::Contains(i)) {
|
||||
VLOG(1) << "Copy output: index " << i << " skipped.";
|
||||
continue;
|
||||
}
|
||||
CAFFE_ENFORCE(
|
||||
local_output_blobs_[i]->template IsType<TensorCPU>(),
|
||||
"MKL fallback op currently does not support non-TensorCPU "
|
||||
"output type who needs copying.");
|
||||
const auto& src = local_output_blobs_[i]->template Get<TensorCPU>();
|
||||
if (src.IsType<float>()) {
|
||||
Blob& dst = OperatorBase::OutputAt(i);
|
||||
if (!dst.IsType<MKLMemory<float>>() ||
|
||||
dst.Get<MKLMemory<float>>().dims() != src.dims()) {
|
||||
dst.Reset(new MKLMemory<float>(src.dims());
|
||||
}
|
||||
dst.GetMutable < MKLMemory<float>()->CopyFrom(src);
|
||||
} else if (src.IsType<double>()) {
|
||||
Blob& dst = OperatorBase::OutputAt(i);
|
||||
if (!dst.IsType<MKLMemory<double>>() ||
|
||||
dst.Get<MKLMemory<double>>().dims() != src.dims()) {
|
||||
dst.Reset(new MKLMemory<double>(src.dims());
|
||||
}
|
||||
dst.GetMutable < MKLMemory<double>()->CopyFrom(src);
|
||||
} else {
|
||||
CAFFE_THROW("MKLMemory only supports float and double.");
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
Workspace local_ws_;
|
||||
vector<Blob*> local_input_blobs_;
|
||||
vector<Blob*> local_output_blobs_;
|
||||
std::unique_ptr<CPUOp> base_op_;
|
||||
};
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_
|
||||
|
|
@ -97,9 +97,14 @@ class ConvPoolOpBase : public Operator<Context> {
|
|||
// it may not be identical to the input channels.
|
||||
// This function can be used in the forward functions to obtain the output
|
||||
// sizes.
|
||||
// Note(jiayq): the templatization of this function is mainly to help
|
||||
// implementations that do not use first-class Tensor objects, such as the
|
||||
// MKL operator. One can still call this function with dummy
|
||||
// Tensor<CPUContext> objects in order to obtain the sizes.
|
||||
template <typename AlternativeContext>
|
||||
void SetOutputSize(
|
||||
const Tensor<Context>& input,
|
||||
Tensor<Context>* output,
|
||||
const Tensor<AlternativeContext>& input,
|
||||
Tensor<AlternativeContext>* output,
|
||||
int output_channel) {
|
||||
CAFFE_ENFORCE(4 == input.ndim());
|
||||
CAFFE_ENFORCE(input.size() > 0);
|
||||
|
|
@ -119,7 +124,7 @@ class ConvPoolOpBase : public Operator<Context> {
|
|||
W = input.dim32(3);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown Storage order: " << order_;
|
||||
CAFFE_THROW("Unknown Storage order: ", order_);
|
||||
}
|
||||
|
||||
int output_height = 0, output_width = 0;
|
||||
|
|
@ -191,7 +196,7 @@ class ConvPoolOpBase : public Operator<Context> {
|
|||
// VLOG(2) << "Running NCHW";
|
||||
return RunOnDeviceWithOrderNCHW();
|
||||
default:
|
||||
LOG(FATAL) << "Unknown storage order: " << order_;
|
||||
CAFFE_THROW("Unknown Storage order: ", order_);
|
||||
}
|
||||
// To suppress old compiler warnings
|
||||
return true;
|
||||
|
|
|
|||
|
|
@ -9,32 +9,6 @@
|
|||
|
||||
namespace caffe2 {
|
||||
|
||||
template <int... values>
|
||||
class SkipIndices {
|
||||
private:
|
||||
template <int V>
|
||||
static inline bool ContainsInternal(const int i) {
|
||||
return (i == V);
|
||||
}
|
||||
template <int First, int Second, int... Rest>
|
||||
static inline bool ContainsInternal(const int i) {
|
||||
return (i == First) && ContainsInternal<Second, Rest...>(i);
|
||||
}
|
||||
|
||||
public:
|
||||
static inline bool Contains(const int i) {
|
||||
return ContainsInternal<values...>(i);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class SkipIndices<> {
|
||||
public:
|
||||
static inline bool Contains(const int i) {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief A templated class to allow one to wrap a CPU operator as a CUDA
|
||||
* operator.
|
||||
|
|
|
|||
|
|
@ -116,6 +116,8 @@ class LayoutWrapper {
|
|||
/**
|
||||
* @brief A wrapper around an opaque MKL internal resource that has certain
|
||||
* layouts and convertion primitives set up.
|
||||
*
|
||||
* Most of the MKLMemory functions are not thread safe.
|
||||
*/
|
||||
template <typename T>
|
||||
class MKLMemory {
|
||||
|
|
@ -131,6 +133,29 @@ class MKLMemory {
|
|||
const dnnPrimitive_t primitive = nullptr,
|
||||
const dnnResourceType_t type = dnnResourceNumber,
|
||||
bool share_mem_if_possible = false) {
|
||||
Reset(dimension, size, strides, primitive, type, share_mem_if_possible);
|
||||
}
|
||||
|
||||
// Initialize an MKLMemory, with the given dimension assuming a C-contiguous
|
||||
// storage.
|
||||
template <typename IndexType>
|
||||
explicit MKLMemory(
|
||||
const vector<IndexType>& dims,
|
||||
const dnnPrimitive_t primitive = nullptr,
|
||||
const dnnResourceType_t type = dnnResourceNumber,
|
||||
bool share_mem_if_possible = false) {
|
||||
Reset(dims, primitive, type, share_mem_if_possible);
|
||||
}
|
||||
|
||||
// Initialize an MKLMemory with the given size, strides, dnn
|
||||
// primitive and type.
|
||||
void Reset(
|
||||
const size_t dimension,
|
||||
const size_t size[],
|
||||
const size_t strides[],
|
||||
const dnnPrimitive_t primitive = nullptr,
|
||||
const dnnResourceType_t type = dnnResourceNumber,
|
||||
bool share_mem_if_possible = false) {
|
||||
dims_.resize(dimension);
|
||||
for (int i = 0; i < dimension; ++i) {
|
||||
dims_[i] = size[dimension - 1 - i];
|
||||
|
|
@ -143,27 +168,22 @@ class MKLMemory {
|
|||
}
|
||||
convert_in_.Reset(dnnConversionCreate<T>, user_layout_, layout_);
|
||||
convert_out_.Reset(dnnConversionCreate<T>, layout_, user_layout_);
|
||||
share_mem_ =
|
||||
share_mem_if_possible && dnnLayoutCompare(layout_, user_layout_);
|
||||
if (!share_mem_) {
|
||||
// If we do not do copy, we will create the buffer and own it.
|
||||
void* allocated = nullptr;
|
||||
MKLDNN_SAFE_CALL(dnnAllocateBuffer<T>(&allocated, layout_));
|
||||
buffer_.reset(allocated, [](void* ptr) -> void {
|
||||
MKLDNN_CHECK(dnnReleaseBuffer<T>(ptr));
|
||||
});
|
||||
}
|
||||
share_mem_if_possible_ = share_mem_if_possible;
|
||||
layout_is_user_layout_ = dnnLayoutCompare<T>(layout_, user_layout_);
|
||||
}
|
||||
|
||||
// Initialize an MKLMemory, with the given dimension assuming a C-contiguous
|
||||
// storage.
|
||||
template <typename IndexType>
|
||||
explicit MKLMemory(
|
||||
void Reset(
|
||||
const vector<IndexType>& dims,
|
||||
const dnnPrimitive_t primitive = nullptr,
|
||||
const dnnResourceType_t type = dnnResourceNumber,
|
||||
bool share_mem_if_possible = false) {
|
||||
dims_ = dims;
|
||||
dims_.resize(dims.size());
|
||||
for (int i = 0; i < dims.size(); ++i) {
|
||||
dims_[i] = dims[i];
|
||||
}
|
||||
size_t dimension = dims.size();
|
||||
size_t size[dimension];
|
||||
size_t strides[dimension];
|
||||
|
|
@ -179,27 +199,19 @@ class MKLMemory {
|
|||
}
|
||||
convert_in_.Reset(dnnConversionCreate<T>, user_layout_, layout_);
|
||||
convert_out_.Reset(dnnConversionCreate<T>, layout_, user_layout_);
|
||||
share_mem_ =
|
||||
share_mem_if_possible && dnnLayoutCompare<T>(layout_, user_layout_);
|
||||
if (!share_mem_) {
|
||||
// If we do not do copy, we will create the buffer and own it.
|
||||
void* allocated = nullptr;
|
||||
MKLDNN_SAFE_CALL(dnnAllocateBuffer<T>(&allocated, layout_));
|
||||
buffer_.reset(allocated, [](void* ptr) -> void {
|
||||
MKLDNN_CHECK(dnnReleaseBuffer<T>(ptr));
|
||||
});
|
||||
}
|
||||
share_mem_if_possible_ = share_mem_if_possible;
|
||||
layout_is_user_layout_ = dnnLayoutCompare<T>(layout_, user_layout_);
|
||||
}
|
||||
|
||||
// Destructs the MKLMemory.
|
||||
~MKLMemory() {}
|
||||
|
||||
void CopyFrom(const void* ptr) {
|
||||
if (share_mem_) {
|
||||
if (share_mem_if_possible_ && layout_is_user_layout_) {
|
||||
buffer_.reset(const_cast<void*>(ptr), [](void*) -> void {});
|
||||
} else {
|
||||
MKLDNN_SAFE_CALL(dnnConversionExecute<T>(
|
||||
convert_in_, const_cast<void*>(ptr), buffer_.get()));
|
||||
convert_in_, const_cast<void*>(ptr), buffer()));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -211,8 +223,19 @@ class MKLMemory {
|
|||
CopyFrom(tensor.template data<T>());
|
||||
}
|
||||
|
||||
void CopyFrom(const MKLMemory<T>& other) {
|
||||
if (share_mem_if_possible_ && dnnLayoutCompare(other.layout_, layout_)) {
|
||||
buffer_ = other.buffer_;
|
||||
} else {
|
||||
PrimitiveWrapper<T> convert(
|
||||
dnnConversionCreate<T>, other.layout_, layout_);
|
||||
MKLDNN_SAFE_CALL(
|
||||
dnnConversionExecute<T>(convert, other.buffer_, buffer()));
|
||||
}
|
||||
}
|
||||
|
||||
bool ShareFrom(const void* ptr) {
|
||||
if (share_mem_) {
|
||||
if (share_mem_if_possible_ && layout_is_user_layout_) {
|
||||
buffer_.reset(const_cast<void*>(ptr), [](void*) -> void {});
|
||||
return true;
|
||||
} else {
|
||||
|
|
@ -228,13 +251,22 @@ class MKLMemory {
|
|||
return ShareFrom(tensor.template data<T>());
|
||||
}
|
||||
|
||||
bool ShareFrom(const MKLMemory<T>& other) {
|
||||
if (share_mem_if_possible_ && dnnLayoutCompare(other.layout_, layout_)) {
|
||||
buffer_ = other.buffer_;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void CopyTo(void* ptr) const {
|
||||
if (buffer_.get() == ptr) {
|
||||
// This is already mapping to the same memory region. Skip copy.
|
||||
return;
|
||||
}
|
||||
CAFFE_ENFORCE(
|
||||
buffer_.get(), "Canot copy out from an empty internal resource.");
|
||||
buffer_.get(), "Canot copy out from an uninitialized MKLMemory.");
|
||||
MKLDNN_SAFE_CALL(dnnConversionExecute<T>(convert_out_, buffer_.get(), ptr));
|
||||
}
|
||||
|
||||
|
|
@ -247,7 +279,29 @@ class MKLMemory {
|
|||
CopyTo(tensor->mutable_data<T>());
|
||||
}
|
||||
|
||||
void CopyTo(MKLMemory<T>* other) {
|
||||
if (buffer_.get() == other->buffer_.get()) {
|
||||
return;
|
||||
}
|
||||
CAFFE_ENFORCE(
|
||||
buffer_.get(), "Canot copy out from an uninitialized MKLMemory.");
|
||||
// TODO(jiayq): if primitive creation is a big overhead and we will be
|
||||
// consistently copying stuff with fixed src and dst layouts, consider
|
||||
// making a cache for the primitive below.
|
||||
PrimitiveWrapper<T> convert(
|
||||
dnnConversionCreate<T>, layout_, other->layout_);
|
||||
MKLDNN_SAFE_CALL(
|
||||
dnnConversionExecute<T>(convert, buffer_, other->buffer()));
|
||||
}
|
||||
|
||||
inline void* buffer() {
|
||||
if (buffer_ == nullptr) {
|
||||
void* allocated = nullptr;
|
||||
MKLDNN_SAFE_CALL(dnnAllocateBuffer<T>(&allocated, layout_));
|
||||
buffer_.reset(allocated, [](void* ptr) -> void {
|
||||
MKLDNN_CHECK(dnnReleaseBuffer<T>(ptr));
|
||||
});
|
||||
}
|
||||
return buffer_.get();
|
||||
}
|
||||
|
||||
|
|
@ -279,7 +333,8 @@ class MKLMemory {
|
|||
}
|
||||
|
||||
private:
|
||||
bool share_mem_;
|
||||
bool share_mem_if_possible_;
|
||||
bool layout_is_user_layout_;
|
||||
// The internal buffer in the specific dnn layout.
|
||||
std::shared_ptr<void> buffer_;
|
||||
// The dimensions in the same order as Caffe2 does. This is used to
|
||||
|
|
|
|||
Loading…
Reference in a new issue