caffe2: rebatching queue for MultiTask
Summary: RFC. This is a naive implementation of a Rebatching Queue for the MultiTask effort. Full disclaimer: I'm very new to Caffe/machine learning and I'm doing dodgy science here (under Dmytro's supervision), so please be extra tough on this review so I can learn best practices :)

Differential Revision: D4871970

fbshipit-source-id: 924820ef0fce45b5e2bdabeec9885cbafa23a880
Parent: 222b781f76
Commit: ee7b3c9b2b
5 changed files with 743 additions and 0 deletions
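
For reviewers who want to try the new operators end to end, a minimal sketch distilled from the test file below (net and blob names are illustrative, not part of this change):

    from caffe2.python import core, workspace

    net = core.Net('rebatching_queue_example')

    # A single-component queue that holds at most 10 elements.
    queue = net.CreateRebatchingQueue([], 1, capacity=10, num_blobs=1)

    # Enqueue one element and read it back.
    tensor = net.ConstantFill([], 1, value=1.0, run_once=False)
    net.EnqueueRebatchingQueue([queue, tensor], [])
    result = net.DequeueRebatchingQueue([queue], 1)

    # Close the queue; once it is drained, further enqueues/dequeues fail.
    net.CloseRebatchingQueue([queue], 0)

    workspace.RunNetOnce(net)
    print(workspace.FetchBlob(result))  # -> [1.0]
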
caffe2/python/operator_test/rebatching_queue_test.py (new file, 287 additions)
@@ -0,0 +1,287 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core, workspace
from caffe2.python.test_util import TestCase

import numpy as np
import numpy.testing as npt

from hypothesis import given
import hypothesis.strategies as st

import functools


def primefac(n):
    ret = []
    divisor = 2
    while divisor * divisor <= n:
        while (n % divisor) == 0:
            ret.append(divisor)
            n = n // divisor
        divisor = divisor + 1
    if n > 1:
        ret.append(n)
    return ret


class TestReBatchingQueue(TestCase):
    def test_rebatching_queue_single_enqueue_dequeue(self):
        net = core.Net('net')

        tensors = [
            net.ConstantFill([], 1, value=1.0, run_once=False)
            for times in range(3)
        ]

        queue = net.CreateRebatchingQueue([], 1, capacity=10, num_blobs=1)

        net.EnqueueRebatchingQueue([queue, tensors[0]], [])
        net.EnqueueRebatchingQueue([queue, tensors[1]], [])
        net.EnqueueRebatchingQueue([queue, tensors[2]], [])

        results = [
            net.DequeueRebatchingQueue([queue], 1),
            net.DequeueRebatchingQueue([queue], 1),
            net.DequeueRebatchingQueue([queue], 1),
        ]

        workspace.RunNetOnce(net)

        for idx in range(3):
            self.assertEquals(workspace.FetchBlob(results[idx]), [1.0])

    def test_rebatching_queue_multi_enqueue_dequeue(self):
        net = core.Net('net')
        workspace.FeedBlob(
            "tensors", np.array([x for x in range(10)], np.int32)
        )

        queue = net.CreateRebatchingQueue([], 1, capacity=10, num_blobs=1)

        net.EnqueueRebatchingQueue([queue, "tensors"], [], enqueue_batch=True)

        results = [
            net.DequeueRebatchingQueue([queue], 1, num_elements=5),
            net.DequeueRebatchingQueue([queue], 1, num_elements=5),
        ]

        workspace.RunNetOnce(net)

        npt.assert_array_equal(
            workspace.FetchBlob(results[0]), workspace.FetchBlob("tensors")[:5]
        )
        npt.assert_array_equal(
            workspace.FetchBlob(results[1]), workspace.FetchBlob("tensors")[5:]
        )

    def test_rebatching_queue_closes_properly(self):
        net = core.Net('net')
        workspace.FeedBlob(
            "tensors", np.array([x for x in range(10)], np.int32)
        )

        queue = net.CreateRebatchingQueue([], 1, capacity=10, num_blobs=1)

        net.EnqueueRebatchingQueue([queue, "tensors"], 0, enqueue_batch=True)

        net.CloseRebatchingQueue([queue], 0)

        results = [
            net.DequeueRebatchingQueue([queue], 1, num_elements=5),
            net.DequeueRebatchingQueue([queue], 1, num_elements=5),
        ]

        workspace.RunNetOnce(net)

        npt.assert_array_equal(
            workspace.FetchBlob(results[0]), workspace.FetchBlob("tensors")[:5]
        )
        npt.assert_array_equal(
            workspace.FetchBlob(results[1]), workspace.FetchBlob("tensors")[5:]
        )

        # Enqueuing more should fail now since the queue is closed
        net.EnqueueRebatchingQueue([queue, "tensors"], [], enqueue_batch=True)

        with self.assertRaises(RuntimeError):
            workspace.RunNetOnce(net)

        # Dequeuing more should fail now since the queue is closed
        results = [
            net.DequeueRebatchingQueue([queue], 1, num_elements=5),
        ]

        with self.assertRaises(RuntimeError):
            workspace.RunNetOnce(net)

    def test_rebatching_queue_multiple_components(self):
        NUM_BLOBS = 4
        NUM_ELEMENTS = 10

        net = core.Net('net')

        workspace.blobs['complex_tensor'] = np.array(
            [[x, x + 1] for x in range(NUM_ELEMENTS)], dtype=np.int32
        )

        tensors = [
            net.GivenTensorIntFill(
                [],
                1,
                shape=[NUM_ELEMENTS],
                values=[x for x in range(NUM_ELEMENTS)]
            ),
            net.GivenTensorFill(
                [],
                1,
                shape=[NUM_ELEMENTS],
                values=[x * 1.0 for x in range(NUM_ELEMENTS)]
            ),
            net.GivenTensorBoolFill(
                [],
                1,
                shape=[NUM_ELEMENTS],
                values=[(x % 2 == 0) for x in range(NUM_ELEMENTS)]
            ),
            'complex_tensor',
        ]

        queue = net.CreateRebatchingQueue(
            [], 1, capacity=10, num_blobs=NUM_BLOBS
        )

        net.EnqueueRebatchingQueue([queue] + tensors, [], enqueue_batch=True)

        results = net.DequeueRebatchingQueue([queue], NUM_BLOBS, num_elements=5)

        workspace.RunNetOnce(net)

        for idx in range(NUM_BLOBS):
            npt.assert_array_equal(
                workspace.FetchBlob(results[idx]),
                workspace.FetchBlob(tensors[idx])[:5]
            )

    @given(
        num_producers=st.integers(1, 5),
        num_consumers=st.integers(1, 5),
        producer_input_size=st.integers(1, 10),
        producer_num_iterations=st.integers(1, 10),
        capacity=st.integers(1, 10)
    )
    def test_rebatching_parallel_producer_consumer(
        self, num_producers, num_consumers, producer_input_size,
        producer_num_iterations, capacity
    ):
        ### Init ###
        total_inputs = producer_num_iterations * producer_input_size * num_producers
        inputs = []
        init_net = core.Net('init_net')
        queue = init_net.CreateRebatchingQueue(
            [], 1, capacity=capacity, num_blobs=1
        )

        ### Producers ###
        producer_steps = []
        for i in range(num_producers):
            name = 'producer_%d' % i
            net = core.Net(name)
            values = [
                producer_input_size * i + x for x in range(producer_input_size)
            ]
            for _ in range(producer_num_iterations):
                inputs.extend(values)
            tensors = net.GivenTensorIntFill(
                [], 1, shape=[producer_input_size], values=values
            )

            net.EnqueueRebatchingQueue([queue, tensors], [], enqueue_batch=True)

            step = core.execution_step(
                name, net, num_iter=producer_num_iterations
            )
            producer_steps.append(step)

        producer_step = core.execution_step(
            'producer', [
                core.execution_step(
                    'producers', producer_steps, concurrent_substeps=True
                )
            ]
        )

        ### Consumers ###
        outputs = []

        def append(ins, outs):
            # Extend is atomic
            outputs.extend(ins[0].data.tolist())

        consumer_steps = []
        for i in range(num_consumers):
            # This is just a way of deterministically reading all the elements.
            # We make `num_consumers` almost equal splits
            # (the remainder goes to the last consumer).
            num_elements_to_read = total_inputs // num_consumers
            if i == num_consumers - 1:
                num_elements_to_read = num_elements_to_read \
                    + total_inputs % num_consumers

            # If we have nothing to read this consumer will be idle
            if (num_elements_to_read == 0):
                continue

            # Now we have to make a split on number of iterations and the read
            # size for each iteration. This is again just one of many
            # deterministic ways of doing it. We factorize the total number of
            # elements we have to read and assign half of the factors to the
            # iterations and half to the read size.
            factors = list(primefac(num_elements_to_read))

            num_elements_per_iteration = functools.reduce(
                lambda x, y: x * y, factors[len(factors) // 2:], 1
            )

            num_iterations = functools.reduce(
                lambda x, y: x * y, factors[:len(factors) // 2], 1
            )

            name = 'consumer_%d' % i
            net = core.Net(name)
            blobs = net.DequeueRebatchingQueue(
                [queue], 1, num_elements=num_elements_per_iteration
            )
            net.Python(append)([blobs], 0)
            consumer_steps.append(
                core.execution_step(name, net, num_iter=num_iterations)
            )

        consumer_step = core.execution_step(
            'consumer', consumer_steps, concurrent_substeps=True
        )

        init_step = core.execution_step('init', init_net)
        worker_step = core.execution_step(
            'worker', [consumer_step, producer_step], concurrent_substeps=True
        )

        ### Execute Plan ###
        plan = core.Plan('test')
        plan.AddStep(init_step)
        plan.AddStep(worker_step)

        self.ws.run(plan)

        ### Check Results ###
        # We check that the outputs are a permutation of inputs
        inputs.sort()
        outputs.sort()
        self.assertEquals(inputs, outputs)


if __name__ == "__main__":
    import unittest
    unittest.main()
caffe2/queue/rebatching_queue.cc (new file, 232 additions)
@@ -0,0 +1,232 @@
#include "rebatching_queue.h"
#include "caffe2/utils/smart_tensor_printer.h"

namespace caffe2 {

namespace {

// This concat function will always create a new first dimension to concat
void concat(
    CPUContext& context,
    const std::vector<std::vector<TensorCPU>>& inputs,
    const std::vector<TensorCPU*>& outputs) {
  CAFFE_ENFORCE(!inputs.empty());

  const auto& inputZero = inputs[0];
  const auto numTensors = inputZero.size();
  const auto numRows = inputs.size();

  // Precompute the output sizes to avoid resizing
  std::vector<std::vector<TIndex>> outputDims(numTensors);

  for (int i = 0; i < numTensors; ++i) {
    SmartTensorPrinter::PrintTensor(inputZero.at(i));
    outputDims[i] = inputZero.at(i).dims();
    outputDims[i].insert(outputDims[i].begin(), numRows);
  }

  // Resize to the final output size
  std::vector<void*> destinations(numTensors);
  for (int i = 0; i < numTensors; ++i) {
    outputs[i]->Resize(outputDims[i]);
    destinations[i] = outputs[i]->raw_mutable_data(inputZero[i].meta());
  }

  for (int i = 0; i < numRows; ++i) {
    CAFFE_ENFORCE_EQ(inputs[i].size(), numTensors);

    for (int j = 0; j < numTensors; ++j) {
      const auto& input = inputs[i][j];

      CAFFE_ENFORCE(inputZero[j].meta() == input.meta());
      CAFFE_ENFORCE_EQ(inputZero[j].itemsize(), input.itemsize());
      CAFFE_ENFORCE_EQ(inputZero[j].ndim(), input.ndim());
      for (int k = 0; k < input.ndim(); ++k) {
        CAFFE_ENFORCE_EQ(input.dims()[k], inputZero[j].dims()[k]);
      }

      // Skip empty tensors
      if (input.size() == 0) {
        continue;
      }

      context.CopyItems<CPUContext, CPUContext>(
          input.meta(),
          input.size(),
          input.raw_data() /* src */,
          destinations[j] /* dst */
      );

      destinations[j] =
          (char*)destinations[j] + input.size() * input.itemsize();
    }
  }
}

auto split(CPUContext& context, const std::vector<const TensorCPU*>& inputs) {
  CAFFE_ENFORCE(!inputs.empty());

  const auto outputSize = inputs[0]->dims().at(0);
  std::vector<std::vector<TensorCPU>> outputs(outputSize);

  for (const auto* inputPtr : inputs) {
    CAFFE_ENFORCE(inputPtr);

    const auto& input = *inputPtr;
    const auto innerSize = input.size_from_dim(1);
    const auto itemSize = input.meta().itemsize();

    auto outputDims = input.dims();
    CAFFE_ENFORCE(!outputDims.empty());
    outputDims.erase(outputDims.begin());
    CAFFE_ENFORCE_EQ(input.dims().at(0), outputSize);

    for (int i = 0; i < outputSize; ++i) {
      outputs[i].push_back(TensorCPU(outputDims));
      context.CopyItems<CPUContext, CPUContext>(
          input.meta(),
          innerSize,
          (char*)input.raw_data() + i * innerSize * itemSize /* src */,
          outputs[i].back().raw_mutable_data(input.meta()) /* dst */);
    }
  }

  return outputs;
}
} // anonymous namespace

RebatchingQueue::RebatchingQueue(size_t capacity, size_t numBlobs)
    : capacity_(capacity), numBlobs_(numBlobs), queue_(capacity) {}

RebatchingQueue::~RebatchingQueue() {
  close();
}

bool RebatchingQueue::canRead() const {
  return tail_ < head_;
}

bool RebatchingQueue::dequeue(
    CPUContext& context,
    size_t numElements,
    const std::vector<TensorCPU*>& outputs) {
  std::vector<std::vector<TensorCPU>> results;
  results.reserve(numElements);

  for (;;) {
    if (results.size() == numElements) {
      break;
    }

    {
      std::unique_lock<std::mutex> lock(mutex_);

      cvEmpty_.wait(lock, [this] { return canRead() || isClosed_; });

      // We only want to stop reading if the queue is empty and closed
      if (!canRead() && isClosed_) {
        break;
      }

      do {
        results.push_back(std::move(queue_[tail_++ % capacity()]));
      } while (canRead() && results.size() < numElements);
    }

    if (numElements == 1) {
      cvOverflow_.notify_one();
    } else {
      cvOverflow_.notify_all();
    }
  }

  if (results.empty()) {
    return false;
  }

  concat(context, results, outputs);

  return true;
}

bool RebatchingQueue::canWrite() const {
  return tail_ + capacity() > head_;
}

bool RebatchingQueue::enqueueOne(
    CPUContext& context,
    const std::vector<const TensorCPU*>& inputs) {
  std::vector<std::vector<TensorCPU>> splittedInputs;
  splittedInputs.emplace_back();
  auto& tensorVector = splittedInputs.back();
  tensorVector.reserve(inputs.size());
  for (const auto* tensorPtr : inputs) {
    tensorVector.push_back(*tensorPtr);
  }

  return enqueue(std::move(splittedInputs));
}

bool RebatchingQueue::enqueueMany(
    CPUContext& context,
    const std::vector<const TensorCPU*>& inputs) {
  CAFFE_ENFORCE_EQ(numBlobs_, inputs.size());

  std::vector<std::vector<TensorCPU>> splittedInputs;
  splittedInputs = split(context, inputs);
  return enqueue(std::move(splittedInputs));
}

bool RebatchingQueue::enqueue(
    std::vector<std::vector<TensorCPU>> splittedInputs) {
  int idx = 0;
  for (;;) {
    if (idx >= splittedInputs.size()) {
      break;
    }

    {
      std::unique_lock<std::mutex> lock(mutex_);

      cvOverflow_.wait(lock, [this] { return canWrite() || isClosed_; });

      if (isClosed_) {
        // If we are here it means that we didn't apply the entire batch and if
        // we get closed in the middle of enqueuing we treat it as a non-success.
        return false;
      }

      do {
        queue_[head_++ % capacity()] = std::move(splittedInputs[idx++]);
      } while (canWrite() && idx < splittedInputs.size());
    }

    cvEmpty_.notify_all();
  }

  return true;
}

size_t RebatchingQueue::capacity() const {
  return capacity_;
}

size_t RebatchingQueue::numBlobs() const {
  return numBlobs_;
}

bool RebatchingQueue::isClosed() const {
  std::lock_guard<std::mutex> g(mutex_);
  return isClosed_;
}

void RebatchingQueue::close() {
  {
    std::lock_guard<std::mutex> g(mutex_);
    isClosed_ = true;
  }

  cvEmpty_.notify_all();
  cvOverflow_.notify_all();
}
} // caffe2
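
The split/concat pair above is the heart of the rebatching behavior: enqueueMany slices each input along its first dimension into single elements, and dequeue stacks the requested number of elements back together under a new first dimension. A small NumPy sketch of the intended effect (illustrative only; the shapes are made up):

    import numpy as np

    # One component with four rows of shape (2,), standing in for a batch blob.
    batch = np.arange(8, dtype=np.int32).reshape(4, 2)

    # enqueueMany / split(): slice along the first dimension into queue elements,
    # each keeping only the remaining dimensions (shape (2,) here).
    elements = [batch[i] for i in range(batch.shape[0])]

    # dequeue(numElements=3) / concat(): stack the dequeued elements back together,
    # creating a new first dimension of size numElements.
    out = np.stack(elements[:3], axis=0)
    assert out.shape == (3, 2)
    assert np.array_equal(out, batch[:3])
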
caffe2/queue/rebatching_queue.h (new file, 68 additions)
@@ -0,0 +1,68 @@
#pragma once

#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>

#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/stats.h"
#include "caffe2/core/tensor.h"

namespace caffe2 {

// TODO: This is a very naive implementation with a single mutex. We can do the
// atomic index + circular queue optimizations or pull something more
// heavy-weight later

class RebatchingQueue {
 public:
  RebatchingQueue(size_t capacity, size_t numBlobs);

  ~RebatchingQueue();

  bool enqueueOne(
      CPUContext& context,
      const std::vector<const TensorCPU*>& inputs);

  bool enqueueMany(
      CPUContext& context,
      const std::vector<const TensorCPU*>& inputs);

  bool dequeue(
      CPUContext& context,
      size_t numElements,
      const std::vector<TensorCPU*>& outputs);

  size_t capacity() const;

  size_t numBlobs() const;

  bool isClosed() const;

  void close();

 private:
  bool enqueue(std::vector<std::vector<TensorCPU>> splittedInputs);

  bool canWrite() const;
  bool canRead() const;

  const size_t capacity_;
  const size_t numBlobs_;

  mutable std::mutex mutex_;

  bool isClosed_{false};

  uint64_t head_{0};
  uint64_t tail_{0};

  std::condition_variable cvEmpty_;
  std::condition_variable cvOverflow_;

  std::vector<std::vector<TensorCPU>> queue_;
};
} // caffe2
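
The header's TODO describes the current design: one mutex guarding monotonically increasing head_/tail_ indices into a fixed-capacity ring, with cvEmpty_ waking readers and cvOverflow_ waking writers. A compact Python sketch of that same locking discipline (an illustration only, not code from this change):

    import threading

    class TinyRebatchingQueue(object):
        def __init__(self, capacity):
            self.capacity = capacity
            self.buffer = [None] * capacity
            self.head = 0          # next slot to write (monotonically increasing)
            self.tail = 0          # next slot to read
            self.closed = False
            self.mutex = threading.Lock()
            self.cv_empty = threading.Condition(self.mutex)     # readers wait here
            self.cv_overflow = threading.Condition(self.mutex)  # writers wait here

        def can_read(self):
            return self.tail < self.head

        def can_write(self):
            return self.tail + self.capacity > self.head

        def enqueue(self, item):
            with self.mutex:
                while not (self.can_write() or self.closed):
                    self.cv_overflow.wait()
                if self.closed:
                    return False   # closed mid-enqueue counts as failure
                self.buffer[self.head % self.capacity] = item
                self.head += 1
                self.cv_empty.notify_all()
                return True

        def dequeue(self):
            with self.mutex:
                while not (self.can_read() or self.closed):
                    self.cv_empty.wait()
                if not self.can_read() and self.closed:
                    return None    # empty and closed: nothing left to read
                item = self.buffer[self.tail % self.capacity]
                self.tail += 1
                self.cv_overflow.notify_all()
                return item

        def close(self):
            with self.mutex:
                self.closed = True
                self.cv_empty.notify_all()
                self.cv_overflow.notify_all()
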
caffe2/queue/rebatching_queue_ops.cc (new file, 73 additions)
@@ -0,0 +1,73 @@
#include "rebatching_queue_ops.h"

namespace caffe2 {

CAFFE_KNOWN_TYPE(RebatchingQueuePtr);

namespace {

REGISTER_CPU_OPERATOR(CreateRebatchingQueue, CreateRebatchingQueueOp);
REGISTER_CPU_OPERATOR(EnqueueRebatchingQueue, EnqueueRebatchingQueueOp);
REGISTER_CPU_OPERATOR(DequeueRebatchingQueue, DequeueRebatchingQueueOp);
REGISTER_CPU_OPERATOR(CloseRebatchingQueue, CloseRebatchingQueueOp);

NO_GRADIENT(CreateRebatchingQueue);
NO_GRADIENT(EnqueueRebatchingQueue);
NO_GRADIENT(DequeueRebatchingQueue);
NO_GRADIENT(CloseRebatchingQueue);

OPERATOR_SCHEMA(CreateRebatchingQueue)
    .NumInputs(0)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Creates the Queue.
)DOC")
    .Output(0, "queue", "object representing the queue")
    .Arg("num_blobs", "Number of input tensors the queue will support")
    .Arg(
        "capacity",
        "Maximal number of elements the queue can hold at any given point");

OPERATOR_SCHEMA(CloseRebatchingQueue)
    .NumInputs(1)
    .NumOutputs(0)
    .SetDoc(R"DOC(
Closes the Queue.
)DOC")
    .Input(0, "queue", "object representing the queue");

OPERATOR_SCHEMA(EnqueueRebatchingQueue)
    .NumInputs(2, INT_MAX)
    .NumOutputs(0)
    .SetDoc(R"DOC(
Enqueues Tensors into the queue.
The number of input tensors should be equal to the number of components passed
during creation of the queue.
If the queue is closed this operation will fail.
If the enqueue_batch argument is set, we split the input tensors along the
first dimension to produce single queue elements.
)DOC")
    .Input(0, "queue", "object representing the queue")
    .Input(1, "tensor", "First tensor to enqueue")
    .Arg(
        "enqueue_batch",
        "Are we enqueuing a batch or just a single element. \
        By default we enqueue a single element.");

OPERATOR_SCHEMA(DequeueRebatchingQueue)
    .NumInputs(1)
    .NumOutputs(1, INT_MAX)
    .SetDoc(R"DOC(
Dequeues Tensors from the queue.
If the queue is closed this might return fewer elements than requested.
If num_elements > 1 the returned elements will be concatenated into one
tensor per component.
)DOC")
    .Input(0, "rebatching_queue", "object representing the queue")
    .Input(1, "tensor", "First tensor to enqueue")
    .Arg(
        "num_elements",
        "Number of elements to dequeue. By default we dequeue one element.");
}
}
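
To make the enqueue_batch and num_elements arguments above concrete, a short sketch in the spirit of the multi-enqueue test (net and blob names are illustrative):

    import numpy as np
    from caffe2.python import core, workspace

    workspace.FeedBlob("tensors", np.arange(10, dtype=np.int32))

    net = core.Net('rebatching_batch_example')
    queue = net.CreateRebatchingQueue([], 1, capacity=10, num_blobs=1)

    # enqueue_batch=True: the 10-row input is split into 10 single elements.
    net.EnqueueRebatchingQueue([queue, "tensors"], [], enqueue_batch=True)

    # num_elements=5: five elements are dequeued and concatenated back into
    # one output tensor per component.
    first_half = net.DequeueRebatchingQueue([queue], 1, num_elements=5)
    second_half = net.DequeueRebatchingQueue([queue], 1, num_elements=5)

    workspace.RunNetOnce(net)
    # workspace.FetchBlob(first_half)  -> [0 1 2 3 4]
    # workspace.FetchBlob(second_half) -> [5 6 7 8 9]
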
caffe2/queue/rebatching_queue_ops.h (new file, 83 additions)
@@ -0,0 +1,83 @@
#pragma once

#include "rebatching_queue.h"

namespace caffe2 {

using RebatchingQueuePtr = std::unique_ptr<RebatchingQueue>;

class CreateRebatchingQueueOp : public Operator<CPUContext> {
 public:
  CreateRebatchingQueueOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator(operator_def, ws) {}

  bool RunOnDevice() override {
    *OperatorBase::Output<RebatchingQueuePtr>(0) =
        RebatchingQueuePtr(new RebatchingQueue(
            OperatorBase::GetSingleArgument<int>("capacity", 1),
            OperatorBase::GetSingleArgument<int>("num_blobs", 1)));
    return true;
  }
};

class EnqueueRebatchingQueueOp : public Operator<CPUContext> {
 public:
  EnqueueRebatchingQueueOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator(operator_def, ws),
        enqueueBatch_(
            OperatorBase::GetSingleArgument<bool>("enqueue_batch", false)) {}
  bool RunOnDevice() override {
    auto& queue = Inputs()[0]->template Get<RebatchingQueuePtr>();
    CHECK(queue);
    CAFFE_ENFORCE_EQ(InputSize(), queue->numBlobs() + 1);
    std::vector<const TensorCPU*> inputTensors;
    inputTensors.reserve(InputSize() - 1);
    for (int i = 1; i < InputSize(); ++i) {
      inputTensors.push_back(&Input(i));
    }

    return enqueueBatch_ ? queue->enqueueMany(context_, inputTensors)
                         : queue->enqueueOne(context_, inputTensors);
  }

 private:
  const bool enqueueBatch_;
};

class DequeueRebatchingQueueOp : public Operator<CPUContext> {
 public:
  DequeueRebatchingQueueOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator(operator_def, ws),
        numElements_(OperatorBase::GetSingleArgument<int>("num_elements", 1)) {}

  bool RunOnDevice() override {
    auto& queue = Inputs()[0]->template Get<RebatchingQueuePtr>();
    CHECK(queue);

    std::vector<TensorCPU*> outputTensors;
    outputTensors.reserve(OutputSize());
    for (int i = 0; i < OutputSize(); ++i) {
      outputTensors.push_back(Output(i));
    }

    return queue->dequeue(context_, numElements_, outputTensors);
  }

 private:
  int numElements_;
};

class CloseRebatchingQueueOp : public Operator<CPUContext> {
 public:
  CloseRebatchingQueueOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator(operator_def, ws) {}

  bool RunOnDevice() override {
    CAFFE_ENFORCE_EQ(InputSize(), 1);
    auto& queue = Inputs()[0]->template Get<RebatchingQueuePtr>();
    CAFFE_ENFORCE(queue);
    queue->close();
    return true;
  }
};
} // caffe2