Create a reserve operation for tensors to avoid reallocating memory on Extend() and Resize() operations

Summary: I want to collect tensors over multiple batches and so this operation could become helpful to allocate enough memory from the beginning

Reviewed By: dzhulgakov

Differential Revision: D4216198

fbshipit-source-id: e6b67cc7d80d71455487878da9b6b7a225035085
This commit is contained in:
Maxime Boucher 2016-11-29 02:54:53 -08:00 committed by Bram Wasti
parent 1aafeb3565
commit f16c2fe3da

View file

@ -160,7 +160,6 @@ class Tensor {
template <class ContextForCopy>
void Extend(TIndex num, float growthPct, ContextForCopy* context) {
CAFFE_ENFORCE_GE(dims_.size(), 1);
auto oldSize = size_;
auto newDims = dims_;
newDims[0] += num;
if (!data_) {
@ -169,20 +168,37 @@ class Tensor {
}
auto newSize = std::accumulate(
newDims.begin(), newDims.end(), 1, std::multiplies<TIndex>());
if (newSize * meta_.itemsize() > capacity_) {
auto newCapacity = dims_;
newCapacity[0] = std::max<size_t>(
newDims[0], std::ceil(dims_[0] * (growthPct + 100) / 100));
auto oldData = std::move(data_);
Resize(newCapacity);
auto* newData = raw_mutable_data(meta_);
context->template CopyItems<ContextForCopy, ContextForCopy>(
meta_, oldSize, oldData.get(), newData);
if (newSize * meta_.itemsize() <= capacity_) {
dims_ = newDims;
size_ = newSize;
return;
}
auto newCapacity = dims_;
newCapacity[0] = std::max<size_t>(
newDims[0], std::ceil(dims_[0] * (growthPct + 100) / 100));
Reserve(newCapacity, context);
dims_ = newDims;
size_ = newSize;
}
template <class T, class ContextForCopy>
void Reserve(const std::vector<T>& newCapacity, ContextForCopy* context) {
auto newSize = std::accumulate(
newCapacity.begin(), newCapacity.end(), 1, std::multiplies<TIndex>());
if (newSize * meta_.itemsize() <= capacity_) {
return;
}
auto oldData = std::move(data_);
auto oldSize = size_;
auto oldDims = dims_;
Resize(newCapacity);
auto* newData = raw_mutable_data(meta_);
context->template CopyItems<ContextForCopy, ContextForCopy>(
meta_, oldSize, oldData.get(), newData);
dims_ = oldDims;
size_ = oldSize;
}
/**
* @brief Shrinks the outer-most dimension to given size, keeping the data.
*