mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Create a reserve operation for tensors to avoid reallocating memory on Extend() and Resize() operations
Summary: I want to collect tensors over multiple batches and so this operation could become helpful to allocate enough memory from the beginning Reviewed By: dzhulgakov Differential Revision: D4216198 fbshipit-source-id: e6b67cc7d80d71455487878da9b6b7a225035085
This commit is contained in:
parent
1aafeb3565
commit
f16c2fe3da
1 changed files with 26 additions and 10 deletions
|
|
@ -160,7 +160,6 @@ class Tensor {
|
|||
template <class ContextForCopy>
|
||||
void Extend(TIndex num, float growthPct, ContextForCopy* context) {
|
||||
CAFFE_ENFORCE_GE(dims_.size(), 1);
|
||||
auto oldSize = size_;
|
||||
auto newDims = dims_;
|
||||
newDims[0] += num;
|
||||
if (!data_) {
|
||||
|
|
@ -169,20 +168,37 @@ class Tensor {
|
|||
}
|
||||
auto newSize = std::accumulate(
|
||||
newDims.begin(), newDims.end(), 1, std::multiplies<TIndex>());
|
||||
if (newSize * meta_.itemsize() > capacity_) {
|
||||
auto newCapacity = dims_;
|
||||
newCapacity[0] = std::max<size_t>(
|
||||
newDims[0], std::ceil(dims_[0] * (growthPct + 100) / 100));
|
||||
auto oldData = std::move(data_);
|
||||
Resize(newCapacity);
|
||||
auto* newData = raw_mutable_data(meta_);
|
||||
context->template CopyItems<ContextForCopy, ContextForCopy>(
|
||||
meta_, oldSize, oldData.get(), newData);
|
||||
if (newSize * meta_.itemsize() <= capacity_) {
|
||||
dims_ = newDims;
|
||||
size_ = newSize;
|
||||
return;
|
||||
}
|
||||
auto newCapacity = dims_;
|
||||
newCapacity[0] = std::max<size_t>(
|
||||
newDims[0], std::ceil(dims_[0] * (growthPct + 100) / 100));
|
||||
Reserve(newCapacity, context);
|
||||
dims_ = newDims;
|
||||
size_ = newSize;
|
||||
}
|
||||
|
||||
template <class T, class ContextForCopy>
|
||||
void Reserve(const std::vector<T>& newCapacity, ContextForCopy* context) {
|
||||
auto newSize = std::accumulate(
|
||||
newCapacity.begin(), newCapacity.end(), 1, std::multiplies<TIndex>());
|
||||
if (newSize * meta_.itemsize() <= capacity_) {
|
||||
return;
|
||||
}
|
||||
auto oldData = std::move(data_);
|
||||
auto oldSize = size_;
|
||||
auto oldDims = dims_;
|
||||
Resize(newCapacity);
|
||||
auto* newData = raw_mutable_data(meta_);
|
||||
context->template CopyItems<ContextForCopy, ContextForCopy>(
|
||||
meta_, oldSize, oldData.get(), newData);
|
||||
dims_ = oldDims;
|
||||
size_ = oldSize;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Shrinks the outer-most dimension to given size, keeping the data.
|
||||
*
|
||||
|
|
|
|||
Loading…
Reference in a new issue