mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add new LengthsSplit operator (#10974)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10974 Pull Request resolved: https://github.com/pytorch/pytorch/pull/10291 This new operator will do the following: Given a LENGTHS vector and n_splits, output a "split" LENGTHS vector where: 1. Each length in input vector is split into n_splits values (thus output vector should have LENGTHS.size(0) * n_splits elements) 2. The new lengths in output should be evenly split, and if the length is not divisible by n_splits, then order new values in descending order. (e.g. n_splits = 3, length = 5 -> 2 2 1) 3. If n_splits > some element in the array, its split elements will contain 0s. (e.g. n_splits = 3, length = 2 - > 1 1 0) Reviewed By: bddppq, chocjy Differential Revision: D9013119 fbshipit-source-id: 82bf3371ec08c41fc3379177f0007afc142e0d84
This commit is contained in:
parent
0b78ae86c5
commit
f2f43ad2da
3 changed files with 263 additions and 0 deletions
37
caffe2/operators/length_split_op.cc
Normal file
37
caffe2/operators/length_split_op.cc
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
#include "caffe2/operators/length_split_op.h"

namespace caffe2 {

// Register the CPU implementation of LengthsSplit (defined in the header).
REGISTER_CPU_OPERATOR(LengthsSplit, LengthsSplitOp<CPUContext>);

// Schema: 1 or 2 inputs (LENGTHS, and optionally an n_split tensor that
// overrides the `n_split` argument), 1 int32 output of size M * n_split.
OPERATOR_SCHEMA(LengthsSplit)
    .NumInputs(1, 2)
    .NumOutputs(1)
    .ScalarType(TensorProto::INT32)
    .SetDoc(R"DOC(
Given input vector LENGTHS, and input n_split, LengthsSplit returns
a single output vector. It "splits" each length into n_split values which add
up to the original length. It will attempt to do equal splits, and if not possible,
it orders larger values first. If the n_split is larger than the length, zero
padding will be applied.

e.g. LENGTHS = [9 4 5]
     n_split = 3
     Y = [3 3 3 2 1 1 2 2 1]

e.g. LENGTHS = [2, 1, 2]
     n_split = 3
     Y = [1 1 0 1 0 0 1 1 0]
)DOC")
    .Arg("n_split", "Number of splits for each element in LENGTHS")
    .Input(0, "LENGTHS", "Mx1 Input tensor denoting INT32 lengths")
    .Input(
        1,
        "n_split",
        "(Optional) Number of splits for each element in LENGTHS (overrides argument)")
    .Output(0, "Y", "(M*n_split)x1 Output vector denoting split lengths");

// TODO: Write gradient for this when needed
GRADIENT_NOT_IMPLEMENTED_YET(LengthsSplit);

} // namespace caffe2
|
||||
75
caffe2/operators/length_split_op.h
Normal file
75
caffe2/operators/length_split_op.h
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
#ifndef CAFFE2_OPERATORS_LENGTH_SPLIT_OP_H_
#define CAFFE2_OPERATORS_LENGTH_SPLIT_OP_H_

#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

// LengthsSplit operator: splits every element of a 1D int32 LENGTHS vector
// into `n_split` values that sum to the original length. Splits are as even
// as possible, larger (ceil) values ordered first; when n_split exceeds a
// length, the trailing slots are zero.
//
// Inputs:  (0) LENGTHS - 1D int32 tensor of M lengths
//          (1) n_split - optional size-1 int32 tensor overriding the argument
// Output:  (0) Y       - int32 tensor of M * n_split split lengths
template <class Context>
class LengthsSplitOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  LengthsSplitOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        n_split_(OperatorBase::GetSingleArgument<int32_t>("n_split", 0)) {
    if (InputSize() == 1) {
      // With no second input, the argument is the only source of n_split,
      // so it must be present and positive.
      CAFFE_ENFORCE(
          OperatorBase::HasArgument("n_split"),
          "Argument `n_split` is missing and was not specified as input.");
      CAFFE_ENFORCE(
          n_split_ > 0,
          "`n_split` must contain a positive value for defined behavior.");
    }
  }
  ~LengthsSplitOp() {}

  bool RunOnDevice() override {
    const auto& L = Input(0);
    CAFFE_ENFORCE_EQ(L.ndim(), 1, "Input `LENGTHS` should be a 1D vector.");

    if (InputSize() > 1) {
      // n_split may also be given as an input, which overrides the argument.
      CAFFE_ENFORCE(
          Input(1).ndim() == 1 && Input(1).size() == 1,
          "Input `n_split` should be a vector of size 1.");

      const auto& input1 = Input(1);
      // Fix: enforce the element type before the raw copy below. CopyItems
      // copies one element of input1.meta()'s size into &n_split_; a wider
      // element type (e.g. int64) would overrun the int32 member.
      CAFFE_ENFORCE(
          input1.template IsType<int32_t>(),
          "Input `n_split` should be an int32 tensor.");
      context_.template CopyItems<Context, CPUContext>(
          input1.meta(), 1, input1.raw_data(), &n_split_);
    }

    CAFFE_ENFORCE(
        n_split_ > 0,
        "`n_split` must contain a positive value for defined behavior.");
    const auto M = L.size();

    auto* Y = Output(0);
    Y->Resize(M * n_split_);

    const int32_t* Ldata = L.template data<int32_t>();
    int32_t* Ydata = Y->template mutable_data<int32_t>();

    // Use a 64-bit index to match L.size()'s type (fixes the signed/width
    // mismatch of the original `int i` loop counter).
    for (int64_t i = 0; i < M; i++) {
      // The first (Ldata[i] % n_split_) slots receive the ceil value `res`,
      // the remaining slots receive res - 1 (the floor value, which is 0
      // when n_split_ > Ldata[i]).
      int32_t mod = Ldata[i] % n_split_;
      int32_t res =
          mod != 0 ? math::divUp(Ldata[i], n_split_) : Ldata[i] / n_split_ + 1;
      for (int32_t j = 0; j < n_split_; j++) {
        Ydata[(i * n_split_) + j] = mod-- > 0 ? res : res - 1;
      }
    }
    return true;
  }

 private:
  int32_t n_split_; // Number of splits per length; argument or Input(1).
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_LENGTH_SPLIT_OP_H_
|
||||
151
caffe2/python/operator_test/length_split_op_test.py
Normal file
151
caffe2/python/operator_test/length_split_op_test.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from caffe2.python import core
|
||||
from hypothesis import given
|
||||
import caffe2.python.hypothesis_test_util as hu
|
||||
import hypothesis.strategies as st
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestLengthSplitOperator(hu.HypothesisTestCase):
    """Checks the LengthsSplit operator against a pure-numpy reference."""

    def _length_split_op_ref(self, input_lengths, n_split_array):
        # Reference: each length becomes n_split values summing to it,
        # with the larger (ceil) values first and zeros padding the tail.
        n_split = n_split_array[0]
        pieces = []
        for length in input_lengths:
            remainder = length % n_split
            larger = length // n_split + 1
            pieces.extend(
                larger if slot < remainder else larger - 1
                for slot in range(n_split)
            )
        return [np.array(pieces).astype(np.int32)]

    @given(**hu.gcs_cpu_only)
    def test_length_split_edge(self, gc, dc):
        # n_split exceeds some lengths, so zero padding appears.
        # Expected Y: [1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1]
        lengths = np.array([3, 4, 5]).astype(np.int32)
        split_input = np.array([5]).astype(np.int32)
        op = core.CreateOperator(
            'LengthsSplit', ['input_lengths', 'n_split'], ['Y'],
        )

        # Compare against the numpy reference.
        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[lengths, split_input],
            reference=self._length_split_op_ref,
        )
        # And across all available devices.
        self.assertDeviceChecks(dc, op, [lengths, split_input], [0])

    @given(**hu.gcs_cpu_only)
    def test_length_split_arg(self, gc, dc):
        # n_split supplied as an operator argument instead of an input.
        # Expected Y: [3, 3, 3, 2, 1, 1, 2, 2, 1]
        lengths = np.array([9, 4, 5]).astype(np.int32)
        n_split = 3
        op = core.CreateOperator(
            'LengthsSplit', ['input_lengths'], ['Y'], n_split=n_split,
        )

        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[lengths],
            reference=lambda x: self._length_split_op_ref(x, [n_split]),
        )
        self.assertDeviceChecks(dc, op, [lengths], [0])

    @given(**hu.gcs_cpu_only)
    def test_length_split_override_arg(self, gc, dc):
        # The second input must take precedence over the n_split argument.
        lengths = np.array([9, 4, 5]).astype(np.int32)
        argument_value = 2
        input_value = np.array([3]).astype(np.int32)

        op = core.CreateOperator(
            'LengthsSplit', ['input_lengths', 'n_split'], ['Y'],
            n_split=argument_value,
        )

        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[lengths, input_value],
            reference=self._length_split_op_ref,
        )
        self.assertDeviceChecks(dc, op, [lengths, input_value], [0])

    @given(m=st.integers(1, 100), n_split=st.integers(1, 20),
           **hu.gcs_cpu_only)
    def test_length_split_even_divide(self, m, n_split, gc, dc):
        # Every length is an exact multiple of n_split (no remainder path).
        lengths = np.random.randint(100, size=m).astype(np.int32) * n_split
        split_input = np.array([n_split]).astype(np.int32)

        op = core.CreateOperator(
            'LengthsSplit', ['input_lengths', 'n_split'], ['Y'],
        )

        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[lengths, split_input],
            reference=self._length_split_op_ref,
        )
        self.assertDeviceChecks(dc, op, [lengths, split_input], [0])

    @given(m=st.integers(1, 100), n_split=st.integers(1, 20),
           **hu.gcs_cpu_only)
    def test_length_split_random(self, m, n_split, gc, dc):
        # Arbitrary lengths exercise both the remainder and padding paths.
        lengths = np.random.randint(100, size=m).astype(np.int32)
        split_input = np.array([n_split]).astype(np.int32)

        op = core.CreateOperator(
            'LengthsSplit', ['input_lengths', 'n_split'], ['Y'],
        )

        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[lengths, split_input],
            reference=self._length_split_op_ref,
        )
        self.assertDeviceChecks(dc, op, [lengths, split_input], [0])
||||
# Allow running this test file directly (outside the test runner).
if __name__ == "__main__":
    import unittest
    unittest.main()
Loading…
Reference in a new issue