Add new LengthsSplit operator (#10974)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10974

Pull Request resolved: https://github.com/pytorch/pytorch/pull/10291

This new operator will do the following:

Given a LENGTHS vector and n_splits, output a "split" LENGTHS vector where:

1. Each length in input vector is split into n_splits values (thus output vector should have LENGTHS.size(0) * n_splits elements)
2. The new lengths in output should be evenly split, and if the length is not divisible by n_splits, the new values are ordered in descending order. (e.g. n_splits = 3, length = 5 -> 2 2 1)
3. If n_splits > some element in the array, its split elements will contain 0s. (e.g. n_splits = 3, length = 2 -> 1 1 0)

Reviewed By: bddppq, chocjy

Differential Revision: D9013119

fbshipit-source-id: 82bf3371ec08c41fc3379177f0007afc142e0d84
This commit is contained in:
Mingda Li 2018-09-10 15:39:01 -07:00 committed by Facebook Github Bot
parent 0b78ae86c5
commit f2f43ad2da
3 changed files with 263 additions and 0 deletions

View file

@ -0,0 +1,37 @@
#include "caffe2/operators/length_split_op.h"
namespace caffe2 {
// Register the CPU implementation (defined in length_split_op.h).
REGISTER_CPU_OPERATOR(LengthsSplit, LengthsSplitOp<CPUContext>);
// Schema: one required input (LENGTHS) and an optional second input that
// overrides the `n_split` argument at runtime; the single output is an
// int32 vector of M * n_split split lengths.
OPERATOR_SCHEMA(LengthsSplit)
.NumInputs(1, 2)
.NumOutputs(1)
.ScalarType(TensorProto::INT32)
.SetDoc(R"DOC(
Given input vector LENGTHS, and input n_split, LengthsSplit returns
a single output vector. It "splits" each length into n_split values which add
up to the original length. It will attempt to do equal splits, and if not possible,
it orders larger values first. If the n_split is larger than the length, zero
padding will be applied.
e.g. LENGTHS = [9 4 5]
n_split = 3
Y = [3 3 3 2 1 1 2 2 1]
e.g. LENGTHS = [2, 1, 2]
n_split = 3
Y = [1 1 0 1 0 0 1 1 0]
)DOC")
.Arg("n_split", "Number of splits for each element in LENGTHS")
.Input(0, "LENGTHS", "Mx1 Input tensor denoting INT32 lengths")
.Input(
1,
"n_split",
"(Optional) Number of splits for each element in LENGTHS (overrides argument)")
.Output(0, "Y", "(M*n_split)x1 Output vector denoting split lengths");
// TODO: Write gradient for this when needed
GRADIENT_NOT_IMPLEMENTED_YET(LengthsSplit);
} // namespace caffe2

View file

@ -0,0 +1,75 @@
#ifndef CAFFE2_OPERATORS_LENGTH_SPLIT_OP_H_
#define CAFFE2_OPERATORS_LENGTH_SPLIT_OP_H_
#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
template <class Context>
// Splits each element of a 1D int32 LENGTHS vector into n_split values that
// sum to the original length, with larger values first and zero padding when
// n_split exceeds the length (e.g. n_split = 3, length = 5 -> [2, 2, 1]).
class LengthsSplitOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  LengthsSplitOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        n_split_(OperatorBase::GetSingleArgument<int32_t>("n_split", 0)) {
    if (InputSize() == 1) {
      // If not specified as a second input, then must have this argument.
      CAFFE_ENFORCE(
          OperatorBase::HasArgument("n_split"),
          "Argument `n_split` is missing and was not specified as input.");
      CAFFE_ENFORCE(
          n_split_ > 0,
          "`n_split` must contain a positive value for defined behavior.");
    }
  }

  ~LengthsSplitOp() {}

  bool RunOnDevice() override {
    const auto& L = Input(0);
    CAFFE_ENFORCE_EQ(L.ndim(), 1, "Input `LENGTHS` should be a 1D vector.");

    if (InputSize() > 1) {
      // n_split may also arrive as a scalar second input, which overrides
      // the argument value.
      const auto& input1 = Input(1);
      CAFFE_ENFORCE(
          input1.ndim() == 1 && input1.size() == 1,
          "Input `n_split` should be a vector of size 1.");
      // Guard the raw CopyItems below: without this check, a wider dtype
      // (e.g. int64) would copy more than sizeof(int32_t) bytes into
      // n_split_ and corrupt adjacent memory.
      CAFFE_ENFORCE(
          input1.template IsType<int32_t>(),
          "Input `n_split` should be of type int32.");
      context_.template CopyItems<Context, CPUContext>(
          input1.meta(), 1, input1.raw_data(), &n_split_);
    }

    CAFFE_ENFORCE(
        n_split_ > 0,
        "`n_split` must contain a positive value for defined behavior.");

    const auto M = L.size();
    auto* Y = Output(0);
    Y->Resize(M * n_split_);

    const int32_t* Ldata = L.template data<int32_t>();
    int32_t* Ydata = Y->template mutable_data<int32_t>();

    for (int i = 0; i < M; i++) {
      // Even split: the first `rem` pieces get one extra element so the
      // n_split_ pieces sum to Ldata[i] and appear in descending order.
      const int32_t base = Ldata[i] / n_split_;
      const int32_t rem = Ldata[i] % n_split_;
      for (int j = 0; j < n_split_; j++) {
        Ydata[(i * n_split_) + j] = j < rem ? base + 1 : base;
      }
    }
    return true;
  }

 private:
  int32_t n_split_; // number of pieces each length is split into
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_LENGTH_SPLIT_OP_H_

View file

@ -0,0 +1,151 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core
from hypothesis import given
import caffe2.python.hypothesis_test_util as hu
import hypothesis.strategies as st
import numpy as np
class TestLengthSplitOperator(hu.HypothesisTestCase):
    """Tests for the LengthsSplit operator against a numpy reference."""

    def _length_split_op_ref(self, input_lengths, n_split_array):
        # Reference implementation: each length is split into n_split
        # pieces that sum to the original; the first (length % n_split)
        # pieces are one larger, so the pieces appear in descending order.
        n_split = n_split_array[0]
        result = []
        for length in input_lengths:
            base, remainder = divmod(length, n_split)
            result.extend([base + 1] * remainder)
            result.extend([base] * (n_split - remainder))
        return [np.array(result).astype(np.int32)]

    @given(**hu.gcs_cpu_only)
    def test_length_split_edge(self, gc, dc):
        # n_split exceeds some of the lengths, so zero padding appears:
        # [1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1]
        lengths = np.array([3, 4, 5]).astype(np.int32)
        splits = np.array([5]).astype(np.int32)
        op = core.CreateOperator(
            'LengthsSplit',
            ['input_lengths', 'n_split'],
            ['Y'],
        )
        # Compare against the numpy reference ...
        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[lengths, splits],
            reference=self._length_split_op_ref,
        )
        # ... and across available devices.
        self.assertDeviceChecks(dc, op, [lengths, splits], [0])

    @given(**hu.gcs_cpu_only)
    def test_length_split_arg(self, gc, dc):
        # n_split supplied as an operator argument only.
        # Expected output: [3, 3, 3, 2, 1, 1, 2, 2, 1]
        lengths = np.array([9, 4, 5]).astype(np.int32)
        n_split = 3
        op = core.CreateOperator(
            'LengthsSplit',
            ['input_lengths'],
            ['Y'], n_split=n_split
        )
        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[lengths],
            reference=lambda x: self._length_split_op_ref(x, [n_split]),
        )
        self.assertDeviceChecks(dc, op, [lengths], [0])

    @given(**hu.gcs_cpu_only)
    def test_length_split_override_arg(self, gc, dc):
        # The n_split input blob takes precedence over the argument.
        lengths = np.array([9, 4, 5]).astype(np.int32)
        arg_ignored = 2
        input_used = np.array([3]).astype(np.int32)
        op = core.CreateOperator(
            'LengthsSplit',
            ['input_lengths', 'n_split'],
            ['Y'], n_split=arg_ignored
        )
        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[lengths, input_used],
            reference=self._length_split_op_ref,
        )
        self.assertDeviceChecks(dc, op, [lengths, input_used], [0])

    @given(m=st.integers(1, 100), n_split=st.integers(1, 20),
           **hu.gcs_cpu_only)
    def test_length_split_even_divide(self, m, n_split, gc, dc):
        # Lengths that are exact multiples of n_split -> equal pieces.
        lengths = np.random.randint(100, size=m).astype(np.int32) * n_split
        splits = np.array([n_split]).astype(np.int32)
        op = core.CreateOperator(
            'LengthsSplit',
            ['input_lengths', 'n_split'],
            ['Y'],
        )
        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[lengths, splits],
            reference=self._length_split_op_ref,
        )
        self.assertDeviceChecks(dc, op, [lengths, splits], [0])

    @given(m=st.integers(1, 100), n_split=st.integers(1, 20),
           **hu.gcs_cpu_only)
    def test_length_split_random(self, m, n_split, gc, dc):
        # Arbitrary non-negative lengths with an arbitrary split count.
        lengths = np.random.randint(100, size=m).astype(np.int32)
        splits = np.array([n_split]).astype(np.int32)
        op = core.CreateOperator(
            'LengthsSplit',
            ['input_lengths', 'n_split'],
            ['Y'],
        )
        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[lengths, splits],
            reference=self._length_split_op_ref,
        )
        self.assertDeviceChecks(dc, op, [lengths, splits], [0])
# Allow running this test file directly (outside a test runner).
if __name__ == "__main__":
    import unittest
    unittest.main()