Add AdjustBatch Op (#16676)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16676

This op is used for changing the batch size (the first dimension) of a tensor.

Reviewed By: bertmaher, ipiszy
Differential Revision: D13929200
fbshipit-source-id: 4f2c3faec072d468be8301bf00c80d33adb3b5b3
This commit is contained in:
parent 100aa0798e
commit e5e0bf4152
3 changed files with 170 additions and 0 deletions
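For orientation, here is a minimal usage sketch of the pad mode that the summary describes (a single input plus the max_batch_size argument), written against the standard caffe2 Python workspace API. The blob names and shapes are illustrative assumptions, not part of this commit.

import numpy as np
from caffe2.python import core, workspace

# Feed a batch of 3 rows, then pad it out to a batch of 8.
workspace.FeedBlob("data", np.random.rand(3, 5).astype(np.float32))
op = core.CreateOperator(
    "AdjustBatch",
    ["data"],
    ["data_padded", "real_batch_size"],
    max_batch_size=8,
)
workspace.RunOperatorOnce(op)
print(workspace.FetchBlob("data_padded").shape)  # (8, 5); rows 3..7 are zero padding
print(workspace.FetchBlob("real_batch_size"))    # [3], the original batch size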
20  caffe2/operators/adjust_batch_op.cc  Normal file
@@ -0,0 +1,20 @@
#include "caffe2/operators/adjust_batch_op.h"

namespace caffe2 {

REGISTER_CPU_OPERATOR(AdjustBatch, AdjustBatchOp<CPUContext>);

OPERATOR_SCHEMA(AdjustBatch)
    .NumInputs(1, 2)
    .NumOutputs(1, 2)
    .Input(0, "Input", "Input data")
    .Input(1, "RealBatchSizeIn", "[Optional] Real batch size")
    .Output(0, "Output", "Data with adjusted batch size")
    .Output(1, "RealBatchSizeOut", "[Optional] Real batch size")
    .Arg("max_batch_size", "(*int*): max batch size")
    .SetDoc(R"DOC(
Adjust the batch size of `input` tensor. With a single input, it pads the batch
dimension up to the `max_batch_size` argument; if a second output is present,
the original input batch size is recorded there. With two inputs, the second
input is expected to contain the batch size to adjust to, and the input data is
truncated accordingly.

Github Links:
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/adjust_batch_op.cc

)DOC");
} // namespace caffe2
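To complement the schema above, a minimal sketch of the two-input (truncate) mode, again using hypothetical blob names against the caffe2 Python API:

import numpy as np
from caffe2.python import core, workspace

# "padded" carries 8 rows, but only the first 3 hold real data.
workspace.FeedBlob("padded", np.random.rand(8, 5).astype(np.float32))
workspace.FeedBlob("real_batch", np.array([3], dtype=np.int64))
op = core.CreateOperator("AdjustBatch", ["padded", "real_batch"], ["trimmed"])
workspace.RunOperatorOnce(op)
print(workspace.FetchBlob("trimmed").shape)  # (3, 5): the first 3 rows of "padded"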
75  caffe2/operators/adjust_batch_op.h  Normal file
@@ -0,0 +1,75 @@
#pragma once

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

template <class Context>
class AdjustBatchOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  AdjustBatchOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        max_batch_size_(
            this->template GetSingleArgument<int64_t>("max_batch_size", -1)) {}

  bool RunOnDevice() override {
    auto& input = Input(0);
    vector<int64_t> output_dims(input.sizes().vec());
    CAFFE_ENFORCE(!output_dims.empty());
    if (InputSize() > 1) {
      // Truncate to the real batch size given by the second input.
      // TODO: if we have a second input and we have max_batch_size set, check
      // the batch size of the two inputs for consistency
      auto& batch_size = Input(1);
      int64_t real_batch_size = *batch_size.template data<int64_t>();
      int64_t max_batch_size = output_dims[0];
      CAFFE_ENFORCE_GE(max_batch_size, real_batch_size);
      output_dims[0] = real_batch_size;
      auto* output = Output(0, output_dims, input.dtype());
      // Copy only the leading real_batch_size rows, i.e. a proportional
      // prefix of the flat buffer.
      this->context_.template CopyItems<Context, Context>(
          input.dtype(),
          input.numel() * real_batch_size / max_batch_size,
          input.raw_data(),
          output->raw_mutable_data(input.dtype()));
    } else {
      // Pad to max batch size
      CAFFE_ENFORCE_GT(
          max_batch_size_,
          0,
          "max_batch_size should be larger than 0. Got ",
          max_batch_size_);

      // TODO: ideally we can support the case when input batch is larger than
      // the max_batch_size, as we can just pad to the multiple of
      // max_batch_size.
      CAFFE_ENFORCE_GE(max_batch_size_, output_dims.front());

      int64_t real_batch_size = output_dims[0];
      output_dims[0] = max_batch_size_;
      auto* output = Output(0, output_dims, input.dtype());
      // Zero-fill the output, then copy the input into the leading rows.
      math::Set(
          output->nbytes(),
          static_cast<char>(0),
          static_cast<char*>(output->raw_mutable_data(input.dtype())),
          &context_);
      this->context_.template CopyItems<Context, Context>(
          input.dtype(),
          input.numel(),
          input.raw_data(),
          output->raw_mutable_data(input.dtype()));

      if (OutputSize() > 1) {
        auto* real_batch_tensor = Output(1, {1}, at::dtype<int64_t>());
        real_batch_tensor->template mutable_data<int64_t>()[0] =
            real_batch_size;
      }
    }

    return true;
  }

 private:
  int64_t max_batch_size_;
};
} // namespace caffe2
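Note that the truncation branch copies a flat prefix of the input buffer, input.numel() * real_batch_size / max_batch_size items, which for a contiguous tensor is equivalent to slicing off the first real_batch_size rows. A small numpy sketch of that equivalence (shapes chosen arbitrarily for illustration):

import numpy as np

x = np.random.rand(4, 3, 2).astype(np.float32)  # max batch size 4
real_batch_size = 2
# Flat-prefix copy, mirroring the CopyItems arithmetic above.
n_items = x.size * real_batch_size // x.shape[0]
prefix = x.ravel()[:n_items].reshape((real_batch_size,) + x.shape[1:])
assert np.array_equal(prefix, x[:real_batch_size])  # same as row slicing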
75  caffe2/python/operator_test/adjust_batch_op_test.py  Normal file
@@ -0,0 +1,75 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, workspace
from hypothesis import given, assume
import caffe2.python.hypothesis_test_util as hu
import hypothesis.strategies as st
import numpy as np

import unittest
import os


class TestAdjustBatchOp(hu.HypothesisTestCase):
    @given(d=st.integers(1, 4), n=st.integers(1, 20),
           seed=st.integers(0, 1000), **hu.gcs_cpu_only)
    def test_pad(self, d, n, gc, dc, seed):
        for dtype in [np.float32, np.int8, np.int64]:
            np.random.seed(seed)
            dims = [n] * d
            X = np.random.rand(*dims).astype(dtype)
            max_batch_size = n + 8

            def ref_op(X):
                shape = list(X.shape)
                out = np.zeros((1), dtype=np.int64)
                out[0] = shape[0]
                shape[0] = max_batch_size
                Y = np.zeros(shape, dtype=dtype)
                Y[:n] = X
                return [Y, out]

            op = core.CreateOperator(
                "AdjustBatch",
                ["X"],
                ["Y", "RealBatch"],
                max_batch_size=max_batch_size,
            )

            self.assertReferenceChecks(
                device_option=gc,
                op=op,
                inputs=[X],
                reference=ref_op,
            )

    @given(d=st.integers(1, 4), n=st.integers(8, 20),
           seed=st.integers(0, 1000), **hu.gcs_cpu_only)
    def test_truncate(self, d, n, gc, dc, seed):
        for dtype in [np.float32, np.int8, np.int64]:
            np.random.seed(seed)
            dims = [n] * d
            X = np.random.rand(*dims).astype(dtype)
            real_batch_size = n - 8
            R = np.zeros((1), dtype=np.int64)
            R[0] = real_batch_size

            def ref_op(X, R):
                r = R[0]
                return [X[:r]]

            op = core.CreateOperator(
                "AdjustBatch",
                ["X", "RealBatch"],
                ["Y"],
            )

            self.assertReferenceChecks(
                device_option=gc,
                op=op,
                inputs=[X, R],
                reference=ref_op,
            )
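Assuming a built PyTorch/Caffe2 environment with hypothesis installed, this test file can presumably be run on its own, e.g. via python -m pytest caffe2/python/operator_test/adjust_batch_op_test.py.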