diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h index 2fc13173954..ed57bec7349 100644 --- a/caffe2/core/tensor.h +++ b/caffe2/core/tensor.h @@ -160,7 +160,6 @@ class Tensor { template void Extend(TIndex num, float growthPct, ContextForCopy* context) { CAFFE_ENFORCE_GE(dims_.size(), 1); - auto oldSize = size_; auto newDims = dims_; newDims[0] += num; if (!data_) { @@ -169,20 +168,37 @@ class Tensor { } auto newSize = std::accumulate( newDims.begin(), newDims.end(), 1, std::multiplies()); - if (newSize * meta_.itemsize() > capacity_) { - auto newCapacity = dims_; - newCapacity[0] = std::max( - newDims[0], std::ceil(dims_[0] * (growthPct + 100) / 100)); - auto oldData = std::move(data_); - Resize(newCapacity); - auto* newData = raw_mutable_data(meta_); - context->template CopyItems( - meta_, oldSize, oldData.get(), newData); + if (newSize * meta_.itemsize() <= capacity_) { + dims_ = newDims; + size_ = newSize; + return; } + auto newCapacity = dims_; + newCapacity[0] = std::max( + newDims[0], std::ceil(dims_[0] * (growthPct + 100) / 100)); + Reserve(newCapacity, context); dims_ = newDims; size_ = newSize; } + template + void Reserve(const std::vector& newCapacity, ContextForCopy* context) { + auto newSize = std::accumulate( + newCapacity.begin(), newCapacity.end(), 1, std::multiplies()); + if (newSize * meta_.itemsize() <= capacity_) { + return; + } + auto oldData = std::move(data_); + auto oldSize = size_; + auto oldDims = dims_; + Resize(newCapacity); + auto* newData = raw_mutable_data(meta_); + context->template CopyItems( + meta_, oldSize, oldData.get(), newData); + dims_ = oldDims; + size_ = oldSize; + } + /** * @brief Shrinks the outer-most dimension to given size, keeping the data. *