Add AdjustBatch Op (#16676)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16676

This op is used for changing the batch size (the first dimension) of a tensor.

Reviewed By: bertmaher, ipiszy

Differential Revision: D13929200

fbshipit-source-id: 4f2c3faec072d468be8301bf00c80d33adb3b5b3
Yinghai Lu 2019-02-06 19:12:32 -08:00 committed by Facebook Github Bot
parent 100aa0798e
commit e5e0bf4152
3 changed files with 170 additions and 0 deletions

caffe2/operators/adjust_batch_op.cc

@@ -0,0 +1,20 @@
#include "caffe2/operators/adjust_batch_op.h"
namespace caffe2 {
REGISTER_CPU_OPERATOR(AdjustBatch, AdjustBatchOp<CPUContext>);
OPERATOR_SCHEMA(AdjustBatch)
.NumInputs(1, 2)
.NumOutputs(1, 2)
.Input(0, "Input", "Input data")
.Input(1, "RealBatchSizeIn", "[Optional] Real batch size")
.Output(0, "Output", "Data with Adjusted batch size")
.Output(1, "RealBatchSizeOut", "[Optional] Real batah size")
.Arg("max_batch_size", "(*int*): max batch size")
.SetDoc(R"DOC(
Adjust the batch size of `input` tensor. When we only have 1 input, it will adjust the batch size according to `max_batch_size` argument. In this case, in addition, if it has two outputs, it will record the input batch size and record it to the second output. When we have 2 inputs, it expects the seocnd input contains the batch size to adjust to, and will truncate the input data accordingly.
Github Links:
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/adjust_batch_op.cc
)DOC");
} // namespace caffe2
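
As a quick illustration of the two modes described in the doc string above, here is a minimal Python sketch (not part of this commit; blob names are illustrative):

import numpy as np
from caffe2.python import core, workspace

# Pad mode: single input; the batch dimension is zero-padded up to
# max_batch_size, and the real batch size lands in the second output.
workspace.FeedBlob("X", np.random.rand(3, 5).astype(np.float32))
workspace.RunOperatorOnce(core.CreateOperator(
    "AdjustBatch", ["X"], ["Y", "RealBatch"], max_batch_size=8))
print(workspace.FetchBlob("Y").shape)    # (8, 5)
print(workspace.FetchBlob("RealBatch"))  # [3]

# Truncate mode: the second input holds the batch size to cut back to.
workspace.FeedBlob("R", np.array([3], dtype=np.int64))
workspace.RunOperatorOnce(core.CreateOperator(
    "AdjustBatch", ["Y", "R"], ["Z"]))
print(workspace.FetchBlob("Z").shape)    # (3, 5)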

caffe2/operators/adjust_batch_op.h

@@ -0,0 +1,75 @@
#pragma once

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

template <class Context>
class AdjustBatchOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  AdjustBatchOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        max_batch_size_(
            this->template GetSingleArgument<int64_t>("max_batch_size", -1)) {}

  bool RunOnDevice() override {
    auto& input = Input(0);
    vector<int64_t> output_dims(input.sizes().vec());
    CAFFE_ENFORCE(!output_dims.empty());
    if (InputSize() > 1) {
      // TODO: if we have a second input and we have max_batch_size set, check
      // the batch size of the two inputs for consistency
      auto& batch_size = Input(1);
      int64_t real_batch_size = *batch_size.template data<int64_t>();
      int64_t max_batch_size = output_dims[0];
      CAFFE_ENFORCE_GE(max_batch_size, real_batch_size);
      output_dims[0] = real_batch_size;
      auto* output = Output(0, output_dims, input.dtype());
      this->context_.template CopyItems<Context, Context>(
          input.dtype(),
          input.numel() * real_batch_size / max_batch_size,
          input.raw_data(),
          output->raw_mutable_data(input.dtype()));
    } else {
      // Pad to max batch size
      CAFFE_ENFORCE_GT(
          max_batch_size_,
          0,
          "max_batch_size should be larger than 0. Got ",
          max_batch_size_);
      // TODO: ideally we can support the case when input batch is larger than
      // the max_batch_size, as we can just pad to the multiple of
      // max_batch_size.
      CAFFE_ENFORCE_GE(max_batch_size_, output_dims.front());
      int64_t real_batch_size = output_dims[0];
      output_dims[0] = max_batch_size_;
      auto* output = Output(0, output_dims, input.dtype());
      math::Set(
          output->nbytes(),
          static_cast<char>(0),
          static_cast<char*>(output->raw_mutable_data(input.dtype())),
          &context_);
      this->context_.template CopyItems<Context, Context>(
          input.dtype(),
          input.numel(),
          input.raw_data(),
          output->raw_mutable_data(input.dtype()));
      if (OutputSize() > 1) {
        auto* real_batch_tensor = Output(1, {1}, at::dtype<int64_t>());
        real_batch_tensor->template mutable_data<int64_t>()[0] =
            real_batch_size;
      }
    }
    return true;
  }

 private:
  int64_t max_batch_size_;
};

} // namespace caffe2
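
For reference, the element count passed to CopyItems in the truncate branch, input.numel() * real_batch_size / max_batch_size, is simply the number of elements in the first real_batch_size entries along the batch dimension. The whole RunOnDevice logic maps onto this NumPy sketch (a hypothetical reference helper, not part of the commit):

import numpy as np

def adjust_batch_ref(x, real_batch_size=None, max_batch_size=None):
    # Truncate branch: a second input gives the real batch size to keep.
    if real_batch_size is not None:
        assert x.shape[0] >= real_batch_size
        return x[:real_batch_size]
    # Pad branch: zero-fill the batch dimension up to max_batch_size and
    # report the real batch size, mirroring the optional second output.
    assert 0 < x.shape[0] <= max_batch_size
    y = np.zeros((max_batch_size,) + x.shape[1:], dtype=x.dtype)
    y[:x.shape[0]] = x
    return y, np.array([x.shape[0]], dtype=np.int64)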

caffe2/python/operator_test/adjust_batch_op_test.py

@@ -0,0 +1,75 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core
from hypothesis import given
import caffe2.python.hypothesis_test_util as hu
import hypothesis.strategies as st
import numpy as np

class TestAdjustBatchOp(hu.HypothesisTestCase):
    @given(d=st.integers(1, 4), n=st.integers(1, 20),
           seed=st.integers(0, 1000), **hu.gcs_cpu_only)
    def test_pad(self, d, n, gc, dc, seed):
        for dtype in [np.float32, np.int8, np.int64]:
            np.random.seed(seed)
            dims = [n] * d
            X = np.random.rand(*dims).astype(dtype)
            max_batch_size = n + 8

            def ref_op(X):
                shape = list(X.shape)
                out = np.zeros((1), dtype=np.int64)
                out[0] = shape[0]
                shape[0] = max_batch_size
                Y = np.zeros(shape, dtype=dtype)
                Y[:n] = X
                return [Y, out]

            op = core.CreateOperator(
                "AdjustBatch",
                ["X"],
                ["Y", "RealBatch"],
                max_batch_size=max_batch_size,
            )
            self.assertReferenceChecks(
                device_option=gc,
                op=op,
                inputs=[X],
                reference=ref_op,
            )

    @given(d=st.integers(1, 4), n=st.integers(8, 20),
           seed=st.integers(0, 1000), **hu.gcs_cpu_only)
    def test_truncate(self, d, n, gc, dc, seed):
        for dtype in [np.float32, np.int8, np.int64]:
            np.random.seed(seed)
            dims = [n] * d
            X = np.random.rand(*dims).astype(dtype)
            real_batch_size = n - 8
            R = np.zeros((1), dtype=np.int64)
            R[0] = real_batch_size

            def ref_op(X, R):
                r = R[0]
                return [X[:r]]

            op = core.CreateOperator(
                "AdjustBatch",
                ["X", "RealBatch"],
                ["Y"],
            )
            self.assertReferenceChecks(
                device_option=gc,
                op=op,
                inputs=[X, R],
                reference=ref_op,
            )
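
Assuming the test file path given above, the suite can also be run on its own, e.g.:

python -m pytest caffe2/python/operator_test/adjust_batch_op_test.py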