ReverseSequence contrib op (#728)

This commit is contained in:
Scott McKay 2019-03-29 17:25:37 +10:00 committed by Changming Sun
parent 333171f602
commit b9b6e3abcb
9 changed files with 528 additions and 33 deletions

View file

@ -28,6 +28,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ConvI
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ROIAlign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, ROIAlign);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReverseSequence);
void RegisterContribKernels(KernelRegistry& kernel_registry) {
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp)>());
@ -54,6 +55,7 @@ void RegisterContribKernels(KernelRegistry& kernel_registry) {
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ROIAlign)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, ROIAlign)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConv)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReverseSequence)>());
}
} // namespace contrib

View file

@ -0,0 +1,160 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "reverse_sequence.h"
#include "onnx/defs/schema.h"
// there's no way to use a raw pointer as the copy destination with std::copy_n
// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset
// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning.
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4996)
#endif
#include "gsl/gsl_algorithm"
#ifdef _MSC_VER
#pragma warning(pop)
#endif
#include "core/framework/utils.h"
#include "core/framework/tensor.h"
#include "core/framework/tensor_shape.h"
namespace onnxruntime {
namespace contrib {
ONNX_OPERATOR_KERNEL_EX(ReverseSequence,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>()),
ReverseSequenceOp);
template <typename T>
static void ReverseSequenceImpl(const Tensor& X, Tensor& Y,
gsl::span<const int> sequence_lengths,
const int64_t max_seq_len,
const int64_t batch_size,
const int64_t input_size,
bool time_major);
Status ReverseSequenceOp::Compute(OpKernelContext* context) const {
Status status = Status::OK();
const auto& X = *context->Input<Tensor>(0);
const auto data_type = X.DataType();
const auto& dims = X.Shape();
const auto batch_size = time_major_ ? dims[1] : dims[0];
const auto max_seq_len = time_major_ ? dims[0] : dims[1];
const auto input_size = dims.SizeFromDimension(2);
const auto& seq_lengths = *context->Input<Tensor>(1);
const auto& seq_len_shape = seq_lengths.Shape();
if (seq_len_shape.NumDimensions() != 1 || seq_len_shape[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "sequence_lens shape must be {batch_size}. Got:",
seq_len_shape, ". batch_size=", batch_size);
}
auto& Y = *context->Output(0, dims);
DispatchOnTensorType(data_type, ReverseSequenceImpl, X, Y, seq_lengths.DataAsSpan<int>(),
max_seq_len, batch_size, input_size, time_major_);
return status;
}
static int64_t TimeMajorInputOffset(const int64_t max_seq_len,
const int64_t batch_size,
const int64_t input_size,
const int64_t batch_num,
const int64_t seq_num) {
ORT_UNUSED_PARAMETER(max_seq_len);
return seq_num * batch_size * input_size + batch_num * input_size;
}
static int64_t BatchMajorInputOffset(const int64_t max_seq_len,
const int64_t batch_size,
const int64_t input_size,
const int64_t batch_num,
const int64_t seq_num) {
ORT_UNUSED_PARAMETER(batch_size);
return batch_num * max_seq_len * input_size + seq_num * input_size;
}
static int64_t TimeMajorOutputOffset(const int64_t max_seq_len,
const int64_t batch_size,
const int64_t input_size,
const int64_t batch_num,
const int64_t seq_num,
const int64_t seq_len) {
ORT_UNUSED_PARAMETER(max_seq_len);
return (seq_len - seq_num - 1) * batch_size * input_size + batch_num * input_size;
}
static int64_t BatchMajorOutputOffset(const int64_t max_seq_len,
const int64_t batch_size,
const int64_t input_size,
const int64_t batch_num,
const int64_t seq_num,
const int64_t seq_len) {
ORT_UNUSED_PARAMETER(batch_size);
return batch_num * max_seq_len * input_size + (seq_len - seq_num - 1) * input_size;
}
template <typename T>
static void ReverseSequenceImpl(const Tensor& X,
Tensor& Y,
gsl::span<const int> sequence_lengths,
const int64_t max_seq_len,
const int64_t batch_size,
const int64_t input_size,
bool time_major) {
gsl::span<const T> inputs = X.DataAsSpan<T>();
gsl::span<T> inputs_reverse = Y.MutableDataAsSpan<T>();
auto input_offset = time_major ? TimeMajorInputOffset : BatchMajorInputOffset;
auto reversed_output_offset = time_major ? TimeMajorOutputOffset : BatchMajorOutputOffset;
for (int i = 0; i < batch_size; i++) {
int seq_len = sequence_lengths[i];
if (seq_len == 0)
continue;
#ifdef USE_OPENMP
// Parallel execute the loop.
#pragma omp parallel for
#endif
for (int j = 0; j < seq_len; j++) {
gsl::span<const T> src = inputs.subspan(input_offset(max_seq_len, batch_size, input_size, i, j), input_size);
gsl::span<T> dest = inputs_reverse.subspan(
reversed_output_offset(max_seq_len, batch_size, input_size, i, j, seq_len), input_size);
// Use gsl::copy instead of std::copy() to allow compiler to optimize the code
gsl::copy(src, dest);
}
#ifdef USE_OPENMP
// Parallel execute the loop.
#pragma omp parallel for
#endif
for (int j = seq_len; j < max_seq_len; j++) {
const auto offset = input_offset(max_seq_len, batch_size, input_size, i, j);
gsl::span<const T> src = inputs.subspan(offset, input_size);
gsl::span<T> dest = inputs_reverse.subspan(offset, input_size);
// Use gsl::copy instead of std::copy() to allow compiler to optimize the code
gsl::copy(src, dest);
}
}
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/framework/tensor.h"
namespace onnxruntime {
namespace contrib {
class ReverseSequenceOp : public OpKernel {
public:
explicit ReverseSequenceOp(const OpKernelInfo& info) : OpKernel(info) {
int64_t batch_axis, time_axis;
ORT_ENFORCE(info.GetAttr<int64_t>("batch_axis", &batch_axis).IsOK());
ORT_ENFORCE(info.GetAttr<int64_t>("time_axis", &time_axis).IsOK());
ORT_ENFORCE(batch_axis < 2, "Invalid batch_axis of ", batch_axis, ". Must be 0 or 1");
ORT_ENFORCE(time_axis < 2, "Invalid time_axis of ", time_axis, ". Must be 0 or 1");
ORT_ENFORCE(batch_axis != time_axis,
"time_axis and batch_axis must have different values but both are ", time_axis);
time_major_ = time_axis == 0;
}
Status Compute(OpKernelContext* context) const override;
private:
bool time_major_;
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -65,33 +65,37 @@ common::Status ExecuteGraphWithCachedInfo(const SessionState& session_state,
const bool& terminate_flag,
const logging::Logger& logger);
#define DispatchOnTensorType(tensor_type, function, ...) \
if (tensor_type == DataTypeImpl::GetType<float>()) \
function<float>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<double>()) \
function<double>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<int8_t>()) \
function<int8_t>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<int16_t>()) \
function<int16_t>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<int32_t>()) \
function<int32_t>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<int64_t>()) \
function<int64_t>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<uint8_t>()) \
function<uint8_t>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<uint16_t>()) \
function<uint16_t>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<uint32_t>()) \
function<uint32_t>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<uint64_t>()) \
function<uint64_t>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<bool>()) \
function<bool>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<MLFloat16>()) \
function<MLFloat16>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<BFloat16>()) \
function<BFloat16>(__VA_ARGS__)
#define DispatchOnTensorType(tensor_type, function, ...) \
if (tensor_type == DataTypeImpl::GetType<float>()) \
function<float>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<double>()) \
function<double>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<int8_t>()) \
function<int8_t>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<int16_t>()) \
function<int16_t>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<int32_t>()) \
function<int32_t>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<int64_t>()) \
function<int64_t>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<uint8_t>()) \
function<uint8_t>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<uint16_t>()) \
function<uint16_t>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<uint32_t>()) \
function<uint32_t>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<uint64_t>()) \
function<uint64_t>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<bool>()) \
function<bool>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<MLFloat16>()) \
function<MLFloat16>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<BFloat16>()) \
function<BFloat16>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<std::string>()) \
function<std::string>(__VA_ARGS__); \
else \
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type)
#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
if (tensor_type == DataTypeImpl::GetType<float>()) \
@ -119,7 +123,11 @@ common::Status ExecuteGraphWithCachedInfo(const SessionState& session_state,
else if (tensor_type == DataTypeImpl::GetType<MLFloat16>()) \
retval = function<MLFloat16>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<BFloat16>()) \
retval = function<BFloat16>(__VA_ARGS__)
retval = function<BFloat16>(__VA_ARGS__); \
else if (tensor_type == DataTypeImpl::GetType<std::string>()) \
retval = function<std::string>(__VA_ARGS__); \
else \
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type)
} // namespace utils
} // namespace onnxruntime

View file

@ -5,6 +5,7 @@
#include "core/graph/contrib_ops/attn_lstm_schema_defs.h"
#include "core/graph/contrib_ops/contrib_defs.h"
#include "core/graph/contrib_ops/range_schema_defs.h"
#include "core/graph/contrib_ops/reverse_sequence_schema_defs.h"
#include "core/graph/op.h"
#include "onnx/defs/schema.h"
#include "onnx/defs/shape_inference.h"
@ -18,9 +19,9 @@ void convPoolTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, bool u
}
namespace onnxruntime {
namespace contrib {
using ::ONNX_NAMESPACE::AttributeProto;
using ::ONNX_NAMESPACE::OPTIONAL;
using ::ONNX_NAMESPACE::OpSchema;
using ONNX_NAMESPACE::AttributeProto;
using ONNX_NAMESPACE::OpSchema;
using ONNX_NAMESPACE::OPTIONAL;
void matmulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx) {
if (!hasInputShape(ctx, input1Idx) && !hasInputShape(ctx, input2Idx)) {
@ -221,7 +222,6 @@ void convPoolShapeInference(
}
void RegisterContribSchemas() {
// ONNX exp ops(Affine, Crop, ParametricSoftplus, ImageScaler) old version history maintainance
static const char* Affine_ver1_doc = R"DOC(
Affine takes one input data (Tensor<T>) and produces one output data
@ -585,6 +585,7 @@ activation and leaky_relu_alpha.)DOC")
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(AttnLSTM, RegisterAttnLSTMContribOpSchema);
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(Range, RegisterRangeOpSchema);
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(ReverseSequence, RegisterReverseSequenceOpSchema);
static const char* Tokenizer_ver1_doc = R"DOC(
Tokenizer divides each string in X into a vector of strings along the last axis. Allowed input shapes are [C] and [N, C].

View file

@ -0,0 +1,110 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "reverse_sequence_schema_defs.h"
#include "core/graph/op.h"
namespace onnxruntime {
namespace contrib {
using ONNX_NAMESPACE::AttributeProto;
using ONNX_NAMESPACE::InferenceContext;
using ONNX_NAMESPACE::OpSchema;
using ONNX_NAMESPACE::OPTIONAL;
using ONNX_NAMESPACE::TensorProto;
using ONNX_NAMESPACE::TensorProto_DataType;
using ONNX_NAMESPACE::TensorShapeProto;
static const char* ReverseSequence_ver1_doc = R"DOC(
Reverse batch of sequences having different lengths specified by `sequence_lens`.
For each slice i iterating on batch axis, the operator reverses the first sequence_lens[i] elements on time axis,
and copies elements whose index's beyond sequence_lens[i] to the output. So the output slice i contains reversed
sequences on the first sequence_lens[i] elements, then have original values copied for the other elements.
Example 1:
input = [[0.0, 4.0, 8.0, 12.0],
[1.0, 5.0, 9.0, 13.0],
[2.0, 6.0, 10.0, 14.0],
[3.0, 7.0, 11.0, 15.0]]
sequence_lens = [4, 3, 2, 1]
time_axis = 0
batch_axis = 1
output = [[3.0, 6.0, 9.0, 12.0],
[2.0, 5.0, 8.0, 13.0],
[1.0, 4.0, 10.0, 14.0],
[0.0, 7.0, 11.0, 15.0]]
Example 2:
input = [[0.0, 1.0, 2.0, 3.0 ],
[4.0, 5.0, 6.0, 7.0 ],
[8.0, 9.0, 10.0, 11.0],
[12.0, 13.0, 14.0, 15.0]]
sequence_lens = [1, 2, 3, 4]
time_axis = 1
batch_axis = 0
output = [[0.0, 1.0, 2.0, 3.0 ],
[5.0, 4.0, 6.0, 7.0 ],
[10.0, 9.0, 8.0, 11.0],
[15.0, 14.0, 13.0, 12.0]]
)DOC";
void ReverseSequenceShapeInference(InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 2)) {
return;
}
auto& first_input_shape = getInputShape(ctx, 0);
if (first_input_shape.dim_size() < 2) {
fail_shape_inference("'input' must have rank >= 2");
}
auto& seq_len_input_shape = getInputShape(ctx, 1);
if (seq_len_input_shape.dim_size() != 1) {
fail_shape_inference("'sequence_lens' must have rank of 1");
}
propagateShapeFromInputToOutput(ctx, 0, 0);
}
OpSchema& RegisterReverseSequenceOpSchema(OpSchema&& op_schema) {
return op_schema
.SetDomain(kMSDomain)
.SinceVersion(1)
.TypeConstraint(
"T",
OpSchema::all_tensor_types(),
"Input and output types can be of any tensor type.")
.Attr(
"time_axis",
"(Optional) Specify which axis is time axis. Must be one of 0 (default), or 1.",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr(
"batch_axis",
"(Optional) Specify which axis is batch axis. Must be one of 1 (default), or 0.",
AttributeProto::INT,
static_cast<int64_t>(1))
.Input(
0,
"input",
"Tensor of rank r >= 2.",
"T")
.Input(
1,
"sequence_lens",
"Tensor specifying lengths of the sequences in a batch. It has shape `[batch_size]`.",
"tensor(int32)")
.Output(
0,
"Y",
"Tensor with same shape of input.",
"T")
.SetDoc(ReverseSequence_ver1_doc)
.TypeAndShapeInferenceFunction(ReverseSequenceShapeInference);
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#endif
#include "onnx/defs/schema.h"
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
namespace onnxruntime {
namespace contrib {
ONNX_NAMESPACE::OpSchema& RegisterReverseSequenceOpSchema(ONNX_NAMESPACE::OpSchema&& op_schema);
} // namespace contrib
} // namespace onnxruntime

View file

@ -64,6 +64,14 @@ Status ShrinkImpl<bool>(const Tensor* /*input*/, Tensor* /*output*/, float /*bia
"to all numeric types only. Got bool type here.");
}
template <>
Status ShrinkImpl<std::string>(const Tensor* /*input*/, Tensor* /*output*/, float /*bias*/, float /*lambd*/) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"Input types for the Shrink operator are constrained "
"to all numeric types only. Got std::string type here.");
}
} // namespace shrink_internal
Status Shrink::Compute(OpKernelContext* p_op_kernel_context) const {
@ -76,4 +84,4 @@ Status Shrink::Compute(OpKernelContext* p_op_kernel_context) const {
DispatchOnTensorTypeWithReturn(dtype, status, ShrinkImpl, input, output, bias_, lambd_);
return status;
}
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -0,0 +1,148 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
TEST(ReverseSequenceTest, BatchMajor) {
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
std::vector<int64_t> input = {0, 1, 2, 3,
4, 5, 6, 7};
std::vector<int32_t> sequence_lens = {4, 3};
std::vector<int64_t> expected_output = {3, 2, 1, 0,
6, 5, 4, 7};
test.AddAttribute("batch_axis", int64_t(0));
test.AddAttribute("time_axis", int64_t(1));
test.AddInput<int64_t>("input", {2, 4, 1}, input);
test.AddInput<int32_t>("sequence_lens", {2}, sequence_lens);
test.AddOutput<int64_t>("Y", {2, 4, 1}, expected_output);
test.Run();
}
TEST(ReverseSequenceTest, TimeMajor) {
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
std::vector<int64_t> input = {0, 4,
1, 5,
2, 6,
3, 7};
std::vector<int32_t> sequence_lens = {4, 3};
std::vector<int64_t> expected_output = {3, 6,
2, 5,
1, 4,
0, 7};
test.AddAttribute("batch_axis", int64_t(1));
test.AddAttribute("time_axis", int64_t(0));
test.AddInput<int64_t>("input", {4, 2, 1}, input);
test.AddInput<int32_t>("sequence_lens", {2}, sequence_lens);
test.AddOutput<int64_t>("Y", {4, 2, 1}, expected_output);
test.Run();
}
TEST(ReverseSequenceTest, LargerDim2) {
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
std::vector<float> input = {0.f, 1.f,
2.f, 3.f,
4.f, 5.f,
6.f, 7.f,
8.f, 9.f,
10.f, 11.f};
std::vector<int32_t> sequence_lens = {2, 3};
std::vector<float> expected_output = {2.f, 3.f,
0.f, 1.f,
4.f, 5.f,
10.f, 11.f,
8.f, 9.f,
6.f, 7.f};
test.AddAttribute("batch_axis", int64_t(0));
test.AddAttribute("time_axis", int64_t(1));
test.AddInput<float>("input", {2, 3, 2}, input);
test.AddInput<int32_t>("sequence_lens", {2}, sequence_lens);
test.AddOutput<float>("Y", {2, 3, 2}, expected_output);
test.Run();
}
TEST(ReverseSequenceTest, Strings) {
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
std::vector<std::string> input = {"0", "4 string longer than 16 chars that requires its own buffer",
"1", "5",
"2", "6",
"3", "7"};
std::vector<int32_t> sequence_lens = {4, 3};
std::vector<std::string> expected_output = {"3", "6",
"2", "5",
"1", "4 string longer than 16 chars that requires its own buffer",
"0", "7"};
test.AddAttribute("batch_axis", int64_t(1));
test.AddAttribute("time_axis", int64_t(0));
test.AddInput<std::string>("input", {4, 2, 1}, input);
test.AddInput<int32_t>("sequence_lens", {2}, sequence_lens);
test.AddOutput<std::string>("Y", {4, 2, 1}, expected_output);
test.Run();
}
TEST(ReverseSequenceTest, InvalidInput) {
{
int64_t batch_size = 2, seq_size = 4;
// Bad axis values
auto check_bad_axis = [&](int64_t batch_dim, int64_t seq_dim,
const std::vector<int64_t>& input_shape,
const std::string err_msg) {
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
std::vector<int64_t> input(batch_size * seq_size, 0);
std::vector<int32_t> sequence_lens(batch_size, 1);
std::vector<int64_t> expected_output = input;
test.AddAttribute("batch_axis", batch_dim);
test.AddAttribute("time_axis", seq_dim);
test.AddInput<int64_t>("input", input_shape, input);
test.AddInput<int32_t>("sequence_lens", {batch_size}, sequence_lens);
test.AddOutput<int64_t>("Y", input_shape, expected_output);
test.Run(test::OpTester::ExpectResult::kExpectFailure, err_msg);
};
check_bad_axis(2, 1, {1, seq_size, batch_size}, "Invalid batch_axis of 2. Must be 0 or 1");
check_bad_axis(0, 2, {batch_size, 1, seq_size}, "Invalid time_axis of 2. Must be 0 or 1");
check_bad_axis(1, 1, {batch_size, seq_size, 1}, "time_axis and batch_axis must have different values but both are 1");
}
// invalid sequence_lens size
{
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
// Bad data_format value
std::vector<int64_t> input = {0, 1, 2, 3,
4, 5, 6, 7};
std::vector<int32_t> sequence_lens = {4, 3, 4};
std::vector<int64_t> expected_output = {3, 2, 1, 0,
6, 5, 4, 7};
test.AddAttribute("batch_axis", int64_t(0));
test.AddAttribute("time_axis", int64_t(1));
test.AddInput<int64_t>("input", {2, 4, 1}, input);
test.AddInput<int32_t>("sequence_lens", {3}, sequence_lens);
test.AddOutput<int64_t>("Y", {2, 4, 1}, expected_output);
test.Run(test::OpTester::ExpectResult::kExpectFailure,
"sequence_lens shape must be {batch_size}. Got:{3}. batch_size=2");
}
}
} // namespace test
} // namespace onnxruntime