mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9939 Pull Request resolved: https://github.com/facebookresearch/weakly-supervised-action-detection/pull/13 Pull Request resolved: https://github.com/pytorch/translate/pull/166 Pull Request resolved: https://github.com/pytorch/pytorch/pull/9125 Closes https://github.com/pytorch/pytorch/pull/9125 Use inheritance for polymorphism, and remove the template parameter. This changes the templating at call sites; the core implementations will change later. Before this change, the Caffe2 Tensor class was bound at compile time to a particular device/context. With this change, the device becomes a runtime property (stored inside the tensor), while the same semantics are preserved. For example, one has to specify the device type in order to create a Tensor - there are no uninitialized tensors. More specifically, the changes are: 1. We added an extra argument *DeviceType* to most of the tensor constructors, e.g. Tensor(DeviceType type). 2. The semantics of the constructor Tensor(const Tensor<SrcContext>& src, ContextForCopy* context) have changed. The second context is passed in so that the templated Copy function can be called. Previously, that context could differ from both the source and the target; now, if it is provided, we enforce that it has the same device type as src. 3. To preserve the 'get-or-construct' semantics of Blob, we added the specialized getter Blob::GetMutableTensor, which verifies both that the Blob contains a Tensor and that it is of the correct type. 4. Tensor is no longer default-constructible (as we don't have unknown-device tensors), so some of the code handling STL containers needs to change. Note: Some changes are postponed just to keep this diff a bit smaller. Please see `TODO`s. Reviewed By: ezyang, houseroad Differential Revision: D9024330 fbshipit-source-id: e0b8295d2dc6ebe2963383ded5af799ad17164ba
234 lines
5.8 KiB
C++
234 lines
5.8 KiB
C++
#include "rebatching_queue.h"
|
|
#include "caffe2/utils/smart_tensor_printer.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
namespace {
|
|
|
|
// This concat function will always create a new first dimension to concat
|
|
void concat(
|
|
CPUContext& context,
|
|
const std::vector<std::vector<TensorCPU>>& inputs,
|
|
const std::vector<TensorCPU*>& outputs) {
|
|
CAFFE_ENFORCE(!inputs.empty());
|
|
|
|
const auto& inputZero = inputs[0];
|
|
const auto numTensors = inputZero.size();
|
|
const auto numRows = inputs.size();
|
|
|
|
// Precompute the output sizes to avoid resizing
|
|
std::vector<std::vector<TIndex>> outputDims(numTensors);
|
|
|
|
for (int i = 0; i < numTensors; ++i) {
|
|
SmartTensorPrinter::PrintTensor(inputZero.at(i));
|
|
outputDims[i] = inputZero.at(i).dims();
|
|
outputDims[i].insert(outputDims[i].begin(), numRows);
|
|
}
|
|
|
|
// Resize to the final output size
|
|
std::vector<void*> destinations(numTensors);
|
|
for (int i = 0; i < numTensors; ++i) {
|
|
outputs[i]->Resize(outputDims[i]);
|
|
destinations[i] = outputs[i]->raw_mutable_data(inputZero[i].meta());
|
|
}
|
|
|
|
for (int i = 0; i < numRows; ++i) {
|
|
CAFFE_ENFORCE_EQ(inputs[i].size(), numTensors);
|
|
|
|
for (int j = 0; j < numTensors; ++j) {
|
|
const auto& input = inputs[i][j];
|
|
|
|
CAFFE_ENFORCE(inputZero[j].meta() == input.meta());
|
|
CAFFE_ENFORCE_EQ(inputZero[j].itemsize(), input.itemsize());
|
|
CAFFE_ENFORCE_EQ(inputZero[j].ndim(), input.ndim());
|
|
for (int k = 0; k < input.ndim(); ++k) {
|
|
CAFFE_ENFORCE_EQ(input.dims()[k], inputZero[j].dims()[k]);
|
|
}
|
|
|
|
// Skip empty tensors
|
|
if (input.size() == 0) {
|
|
continue;
|
|
}
|
|
|
|
context.CopyItemsToCPU(
|
|
input.meta(),
|
|
input.size(),
|
|
input.raw_data() /* src */,
|
|
destinations[j] /* dst */
|
|
);
|
|
|
|
destinations[j] =
|
|
(char*)destinations[j] + input.size() * input.itemsize();
|
|
}
|
|
}
|
|
}
|
|
|
|
// Splits each input tensor along its first dimension.
//
// Every input must be non-null, have at least one dimension, and share the
// same first-dimension size N (taken from inputs[0]). Returns N rows; row i
// holds, for each input in order, a new CPU tensor with that input's dims
// minus the leading one, filled with slice i of the input.
//
// Throws (via CAFFE_ENFORCE) on an empty input list, a null input, a 0-d
// input, or mismatched first dimensions.
std::vector<std::vector<TensorCPU>> split(
    CPUContext& context,
    const std::vector<const TensorCPU*>& inputs) {
  CAFFE_ENFORCE(!inputs.empty());
  // inputs[0] is dereferenced to read N, so validate it before doing so.
  CAFFE_ENFORCE(inputs[0]);
  CAFFE_ENFORCE(!inputs[0]->dims().empty());

  const auto outputSize = inputs[0]->dims()[0];
  std::vector<std::vector<TensorCPU>> outputs(outputSize);

  for (const auto* inputPtr : inputs) {
    CAFFE_ENFORCE(inputPtr);
    const auto& input = *inputPtr;

    // Validate rank and leading dimension before deriving slice geometry.
    auto outputDims = input.dims();
    CAFFE_ENFORCE(!outputDims.empty());
    CAFFE_ENFORCE_EQ(outputDims[0], outputSize);
    outputDims.erase(outputDims.begin());

    const auto innerSize = input.size_from_dim(1);
    const auto itemSize = input.meta().itemsize();
    const auto* src = static_cast<const char*>(input.raw_data());

    for (TIndex i = 0; i < outputSize; ++i) {
      outputs[i].emplace_back(outputDims, CPU);
      context.CopyItemsToCPU(
          input.meta(),
          innerSize,
          src + i * innerSize * itemSize /* src */,
          outputs[i].back().raw_mutable_data(input.meta()) /* dst */);
    }
  }

  return outputs;
}
|
|
} // anonymous namespace
|
|
|
|
// Creates a queue with room for `capacity` elements, where each element is a
// group of `numBlobs` tensors. queue_ is sized to `capacity` slots up front
// and used as a ring buffer (indexed modulo capacity()).
RebatchingQueue::RebatchingQueue(size_t capacity, size_t numBlobs)
    : capacity_(capacity),
      numBlobs_(numBlobs),
      queue_(capacity) {}
|
|
|
|
// Marks the queue closed and wakes any waiters before the members are torn
// down.
RebatchingQueue::~RebatchingQueue() {
  this->close();
}
|
|
|
|
bool RebatchingQueue::canRead() const {
|
|
return tail_ < head_;
|
|
}
|
|
|
|
// Blocks until `numElements` elements have been taken from the queue, or the
// queue is closed and drained, then stacks whatever was taken into `outputs`
// with a new leading dimension (see concat).
//
// Returns false when no element was taken (the queue was closed and empty, or
// numElements == 0); otherwise returns true, possibly with fewer than
// numElements rows if the queue closed mid-way.
bool RebatchingQueue::dequeue(
    CPUContext& context,
    size_t numElements,
    const std::vector<TensorCPU*>& outputs) {
  std::vector<std::vector<TensorCPU>> taken;
  taken.reserve(numElements);

  while (taken.size() != numElements) {
    {
      std::unique_lock<std::mutex> guard(mutex_);
      cvEmpty_.wait(guard, [this] { return isClosed_ || canRead(); });

      // Stop only once the queue is both closed and fully drained.
      const bool drainedAndClosed = isClosed_ && !canRead();
      if (drainedAndClosed) {
        break;
      }

      // Take as many ready elements as are still needed in one critical
      // section.
      do {
        const auto slot = tail_++ % capacity();
        taken.push_back(std::move(queue_[slot]));
      } while (taken.size() < numElements && canRead());
    }

    // Slots were freed: a single-element reader wakes one writer, otherwise
    // every waiting writer is woken.
    if (numElements != 1) {
      cvOverflow_.notify_all();
    } else {
      cvOverflow_.notify_one();
    }
  }

  if (taken.empty()) {
    return false;
  }

  concat(context, taken, outputs);
  return true;
}
|
|
|
|
bool RebatchingQueue::canWrite() const {
|
|
return tail_ + capacity() > head_;
|
|
}
|
|
|
|
// Enqueues a single element made of a Clone() of each tensor in `inputs`.
//
// Requires exactly numBlobs_ non-null inputs (the same arity check that
// enqueueMany performs). Returns false if the queue is closed before the
// element could be written.
bool RebatchingQueue::enqueueOne(
    CPUContext& /*context*/,
    const std::vector<const TensorCPU*>& inputs) {
  CAFFE_ENFORCE_EQ(numBlobs_, inputs.size());

  std::vector<TensorCPU> element;
  element.reserve(inputs.size());
  for (const auto* tensorPtr : inputs) {
    CAFFE_ENFORCE(tensorPtr);
    element.push_back(tensorPtr->Clone());
  }

  std::vector<std::vector<TensorCPU>> splittedInputs;
  splittedInputs.push_back(std::move(element));
  return enqueue(std::move(splittedInputs));
}
|
|
|
|
// Splits every input along its leading dimension (see split) and enqueues one
// element per slice. Requires exactly numBlobs_ inputs.
bool RebatchingQueue::enqueueMany(
    CPUContext& context,
    const std::vector<const TensorCPU*>& inputs) {
  CAFFE_ENFORCE_EQ(numBlobs_, inputs.size());
  return enqueue(split(context, inputs));
}
|
|
|
|
bool RebatchingQueue::enqueue(
|
|
std::vector<std::vector<TensorCPU>> splittedInputs) {
|
|
int idx = 0;
|
|
for (;;) {
|
|
if (idx >= splittedInputs.size()) {
|
|
break;
|
|
}
|
|
|
|
{
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
|
|
cvOverflow_.wait(lock, [this] { return canWrite() || isClosed_; });
|
|
|
|
if (isClosed_) {
|
|
// If we are here it means that we didn't apply the entire batch and if
|
|
// we get closed in the middle of enquing we treat it as a non-success.
|
|
return false;
|
|
}
|
|
|
|
do {
|
|
queue_[head_++ % capacity()] = std::move(splittedInputs[idx++]);
|
|
} while (canWrite() && idx < splittedInputs.size());
|
|
}
|
|
|
|
cvEmpty_.notify_all();
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
size_t RebatchingQueue::capacity() const {
|
|
return capacity_;
|
|
}
|
|
|
|
size_t RebatchingQueue::numBlobs() const {
|
|
return numBlobs_;
|
|
}
|
|
|
|
bool RebatchingQueue::isClosed() const {
|
|
std::lock_guard<std::mutex> g(mutex_);
|
|
return isClosed_;
|
|
}
|
|
|
|
void RebatchingQueue::close() {
|
|
{
|
|
std::lock_guard<std::mutex> g(mutex_);
|
|
isClosed_ = true;
|
|
}
|
|
|
|
cvEmpty_.notify_all();
|
|
cvOverflow_.notify_all();
|
|
}
|
|
} // caffe2
|