diff --git a/onnxruntime/contrib_ops/contrib_kernels.cc b/onnxruntime/contrib_ops/contrib_kernels.cc index 5245ce0e46..60f689ba2b 100644 --- a/onnxruntime/contrib_ops/contrib_kernels.cc +++ b/onnxruntime/contrib_ops/contrib_kernels.cc @@ -28,6 +28,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ConvI class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ROIAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, ROIAlign); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReverseSequence); void RegisterContribKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); @@ -54,6 +55,7 @@ void RegisterContribKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); } } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/reverse_sequence.cc b/onnxruntime/contrib_ops/cpu/reverse_sequence.cc new file mode 100644 index 0000000000..8b21ec8f4e --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/reverse_sequence.cc @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "reverse_sequence.h" +#include "onnx/defs/schema.h" + +// there's no way to use a raw pointer as the copy destination with std::copy_n +// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset +// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning. +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + +#include "gsl/gsl_algorithm" + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +#include "core/framework/utils.h" +#include "core/framework/tensor.h" +#include "core/framework/tensor_shape.h" + +namespace onnxruntime { +namespace contrib { + +ONNX_OPERATOR_KERNEL_EX(ReverseSequence, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + ReverseSequenceOp); + +template +static void ReverseSequenceImpl(const Tensor& X, Tensor& Y, + gsl::span sequence_lengths, + const int64_t max_seq_len, + const int64_t batch_size, + const int64_t input_size, + bool time_major); + +Status ReverseSequenceOp::Compute(OpKernelContext* context) const { + Status status = Status::OK(); + + const auto& X = *context->Input(0); + const auto data_type = X.DataType(); + const auto& dims = X.Shape(); + + const auto batch_size = time_major_ ? dims[1] : dims[0]; + const auto max_seq_len = time_major_ ? dims[0] : dims[1]; + const auto input_size = dims.SizeFromDimension(2); + + const auto& seq_lengths = *context->Input(1); + const auto& seq_len_shape = seq_lengths.Shape(); + + if (seq_len_shape.NumDimensions() != 1 || seq_len_shape[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "sequence_lens shape must be {batch_size}. Got:", + seq_len_shape, ". batch_size=", batch_size); + } + + auto& Y = *context->Output(0, dims); + + DispatchOnTensorType(data_type, ReverseSequenceImpl, X, Y, seq_lengths.DataAsSpan(), + max_seq_len, batch_size, input_size, time_major_); + + return status; +} + +static int64_t TimeMajorInputOffset(const int64_t max_seq_len, + const int64_t batch_size, + const int64_t input_size, + const int64_t batch_num, + const int64_t seq_num) { + ORT_UNUSED_PARAMETER(max_seq_len); + return seq_num * batch_size * input_size + batch_num * input_size; +} + +static int64_t BatchMajorInputOffset(const int64_t max_seq_len, + const int64_t batch_size, + const int64_t input_size, + const int64_t batch_num, + const int64_t seq_num) { + ORT_UNUSED_PARAMETER(batch_size); + return batch_num * max_seq_len * input_size + seq_num * input_size; +} + +static int64_t TimeMajorOutputOffset(const int64_t max_seq_len, + const int64_t batch_size, + const int64_t input_size, + const int64_t batch_num, + const int64_t seq_num, + const int64_t seq_len) { + ORT_UNUSED_PARAMETER(max_seq_len); + return (seq_len - seq_num - 1) * batch_size * input_size + batch_num * input_size; +} + +static int64_t BatchMajorOutputOffset(const int64_t max_seq_len, + const int64_t batch_size, + const int64_t input_size, + const int64_t batch_num, + const int64_t seq_num, + const int64_t seq_len) { + ORT_UNUSED_PARAMETER(batch_size); + return batch_num * max_seq_len * input_size + (seq_len - seq_num - 1) * input_size; +} + +template +static void ReverseSequenceImpl(const Tensor& X, + Tensor& Y, + gsl::span sequence_lengths, + const int64_t max_seq_len, + const int64_t batch_size, + const int64_t input_size, + bool time_major) { + gsl::span inputs = X.DataAsSpan(); + gsl::span inputs_reverse = Y.MutableDataAsSpan(); + + auto input_offset = time_major ? TimeMajorInputOffset : BatchMajorInputOffset; + + auto reversed_output_offset = time_major ? TimeMajorOutputOffset : BatchMajorOutputOffset; + + for (int i = 0; i < batch_size; i++) { + int seq_len = sequence_lengths[i]; + + if (seq_len == 0) + continue; + +#ifdef USE_OPENMP +// Parallel execute the loop. +#pragma omp parallel for +#endif + for (int j = 0; j < seq_len; j++) { + gsl::span src = inputs.subspan(input_offset(max_seq_len, batch_size, input_size, i, j), input_size); + gsl::span dest = inputs_reverse.subspan( + reversed_output_offset(max_seq_len, batch_size, input_size, i, j, seq_len), input_size); + + // Use gsl::copy instead of std::copy() to allow compiler to optimize the code + gsl::copy(src, dest); + } + +#ifdef USE_OPENMP +// Parallel execute the loop. +#pragma omp parallel for +#endif + for (int j = seq_len; j < max_seq_len; j++) { + const auto offset = input_offset(max_seq_len, batch_size, input_size, i, j); + gsl::span src = inputs.subspan(offset, input_size); + gsl::span dest = inputs_reverse.subspan(offset, input_size); + + // Use gsl::copy instead of std::copy() to allow compiler to optimize the code + gsl::copy(src, dest); + } + } +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/reverse_sequence.h b/onnxruntime/contrib_ops/cpu/reverse_sequence.h new file mode 100644 index 0000000000..f8a2c09da4 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/reverse_sequence.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace contrib { + +class ReverseSequenceOp : public OpKernel { + public: + explicit ReverseSequenceOp(const OpKernelInfo& info) : OpKernel(info) { + int64_t batch_axis, time_axis; + ORT_ENFORCE(info.GetAttr("batch_axis", &batch_axis).IsOK()); + ORT_ENFORCE(info.GetAttr("time_axis", &time_axis).IsOK()); + + ORT_ENFORCE(batch_axis < 2, "Invalid batch_axis of ", batch_axis, ". Must be 0 or 1"); + ORT_ENFORCE(time_axis < 2, "Invalid time_axis of ", time_axis, ". Must be 0 or 1"); + + ORT_ENFORCE(batch_axis != time_axis, + "time_axis and batch_axis must have different values but both are ", time_axis); + + time_major_ = time_axis == 0; + } + + Status Compute(OpKernelContext* context) const override; + + private: + bool time_major_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index eded5511b0..db17a22045 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -65,33 +65,37 @@ common::Status ExecuteGraphWithCachedInfo(const SessionState& session_state, const bool& terminate_flag, const logging::Logger& logger); -#define DispatchOnTensorType(tensor_type, function, ...) \ - if (tensor_type == DataTypeImpl::GetType()) \ - function(__VA_ARGS__); \ - else if (tensor_type == DataTypeImpl::GetType()) \ - function(__VA_ARGS__); \ - else if (tensor_type == DataTypeImpl::GetType()) \ - function(__VA_ARGS__); \ - else if (tensor_type == DataTypeImpl::GetType()) \ - function(__VA_ARGS__); \ - else if (tensor_type == DataTypeImpl::GetType()) \ - function(__VA_ARGS__); \ - else if (tensor_type == DataTypeImpl::GetType()) \ - function(__VA_ARGS__); \ - else if (tensor_type == DataTypeImpl::GetType()) \ - function(__VA_ARGS__); \ - else if (tensor_type == DataTypeImpl::GetType()) \ - function(__VA_ARGS__); \ - else if (tensor_type == DataTypeImpl::GetType()) \ - function(__VA_ARGS__); \ - else if (tensor_type == DataTypeImpl::GetType()) \ - function(__VA_ARGS__); \ - else if (tensor_type == DataTypeImpl::GetType()) \ - function(__VA_ARGS__); \ - else if (tensor_type == DataTypeImpl::GetType()) \ - function(__VA_ARGS__); \ - else if (tensor_type == DataTypeImpl::GetType()) \ - function(__VA_ARGS__) +#define DispatchOnTensorType(tensor_type, function, ...) \ + if (tensor_type == DataTypeImpl::GetType()) \ + function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + function(__VA_ARGS__); \ + else \ + ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type) #define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \ if (tensor_type == DataTypeImpl::GetType()) \ @@ -119,7 +123,11 @@ common::Status ExecuteGraphWithCachedInfo(const SessionState& session_state, else if (tensor_type == DataTypeImpl::GetType()) \ retval = function(__VA_ARGS__); \ else if (tensor_type == DataTypeImpl::GetType()) \ - retval = function(__VA_ARGS__) + retval = function(__VA_ARGS__); \ + else if (tensor_type == DataTypeImpl::GetType()) \ + retval = function(__VA_ARGS__); \ + else \ + ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type) } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 27b580f771..472aa7c486 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -5,6 +5,7 @@ #include "core/graph/contrib_ops/attn_lstm_schema_defs.h" #include "core/graph/contrib_ops/contrib_defs.h" #include "core/graph/contrib_ops/range_schema_defs.h" +#include "core/graph/contrib_ops/reverse_sequence_schema_defs.h" #include "core/graph/op.h" #include "onnx/defs/schema.h" #include "onnx/defs/shape_inference.h" @@ -18,9 +19,9 @@ void convPoolTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, bool u } namespace onnxruntime { namespace contrib { -using ::ONNX_NAMESPACE::AttributeProto; -using ::ONNX_NAMESPACE::OPTIONAL; -using ::ONNX_NAMESPACE::OpSchema; +using ONNX_NAMESPACE::AttributeProto; +using ONNX_NAMESPACE::OpSchema; +using ONNX_NAMESPACE::OPTIONAL; void matmulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx) { if (!hasInputShape(ctx, input1Idx) && !hasInputShape(ctx, input2Idx)) { @@ -221,7 +222,6 @@ void convPoolShapeInference( } void RegisterContribSchemas() { - // ONNX exp ops(Affine, Crop, ParametricSoftplus, ImageScaler) old version history maintainance static const char* Affine_ver1_doc = R"DOC( Affine takes one input data (Tensor) and produces one output data @@ -585,6 +585,7 @@ activation and leaky_relu_alpha.)DOC") ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(AttnLSTM, RegisterAttnLSTMContribOpSchema); ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(Range, RegisterRangeOpSchema); + ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(ReverseSequence, RegisterReverseSequenceOpSchema); static const char* Tokenizer_ver1_doc = R"DOC( Tokenizer divides each string in X into a vector of strings along the last axis. Allowed input shapes are [C] and [N, C]. diff --git a/onnxruntime/core/graph/contrib_ops/reverse_sequence_schema_defs.cc b/onnxruntime/core/graph/contrib_ops/reverse_sequence_schema_defs.cc new file mode 100644 index 0000000000..393ff027a4 --- /dev/null +++ b/onnxruntime/core/graph/contrib_ops/reverse_sequence_schema_defs.cc @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "reverse_sequence_schema_defs.h" +#include "core/graph/op.h" + +namespace onnxruntime { +namespace contrib { + +using ONNX_NAMESPACE::AttributeProto; +using ONNX_NAMESPACE::InferenceContext; +using ONNX_NAMESPACE::OpSchema; +using ONNX_NAMESPACE::OPTIONAL; +using ONNX_NAMESPACE::TensorProto; +using ONNX_NAMESPACE::TensorProto_DataType; +using ONNX_NAMESPACE::TensorShapeProto; + +static const char* ReverseSequence_ver1_doc = R"DOC( +Reverse batch of sequences having different lengths specified by `sequence_lens`. + +For each slice i iterating on batch axis, the operator reverses the first sequence_lens[i] elements on time axis, +and copies elements whose index's beyond sequence_lens[i] to the output. So the output slice i contains reversed +sequences on the first sequence_lens[i] elements, then have original values copied for the other elements. + +Example 1: + input = [[0.0, 4.0, 8.0, 12.0], + [1.0, 5.0, 9.0, 13.0], + [2.0, 6.0, 10.0, 14.0], + [3.0, 7.0, 11.0, 15.0]] + sequence_lens = [4, 3, 2, 1] + time_axis = 0 + batch_axis = 1 + + output = [[3.0, 6.0, 9.0, 12.0], + [2.0, 5.0, 8.0, 13.0], + [1.0, 4.0, 10.0, 14.0], + [0.0, 7.0, 11.0, 15.0]] + +Example 2: + input = [[0.0, 1.0, 2.0, 3.0 ], + [4.0, 5.0, 6.0, 7.0 ], + [8.0, 9.0, 10.0, 11.0], + [12.0, 13.0, 14.0, 15.0]] + sequence_lens = [1, 2, 3, 4] + time_axis = 1 + batch_axis = 0 + + output = [[0.0, 1.0, 2.0, 3.0 ], + [5.0, 4.0, 6.0, 7.0 ], + [10.0, 9.0, 8.0, 11.0], + [15.0, 14.0, 13.0, 12.0]] +)DOC"; + +void ReverseSequenceShapeInference(InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (!hasNInputShapes(ctx, 2)) { + return; + } + + auto& first_input_shape = getInputShape(ctx, 0); + if (first_input_shape.dim_size() < 2) { + fail_shape_inference("'input' must have rank >= 2"); + } + auto& seq_len_input_shape = getInputShape(ctx, 1); + if (seq_len_input_shape.dim_size() != 1) { + fail_shape_inference("'sequence_lens' must have rank of 1"); + } + + propagateShapeFromInputToOutput(ctx, 0, 0); +} + +OpSchema& RegisterReverseSequenceOpSchema(OpSchema&& op_schema) { + return op_schema + .SetDomain(kMSDomain) + .SinceVersion(1) + .TypeConstraint( + "T", + OpSchema::all_tensor_types(), + "Input and output types can be of any tensor type.") + .Attr( + "time_axis", + "(Optional) Specify which axis is time axis. Must be one of 0 (default), or 1.", + AttributeProto::INT, + static_cast(0)) + .Attr( + "batch_axis", + "(Optional) Specify which axis is batch axis. Must be one of 1 (default), or 0.", + AttributeProto::INT, + static_cast(1)) + .Input( + 0, + "input", + "Tensor of rank r >= 2.", + "T") + .Input( + 1, + "sequence_lens", + "Tensor specifying lengths of the sequences in a batch. It has shape `[batch_size]`.", + "tensor(int32)") + .Output( + 0, + "Y", + "Tensor with same shape of input.", + "T") + .SetDoc(ReverseSequence_ver1_doc) + .TypeAndShapeInferenceFunction(ReverseSequenceShapeInference); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/reverse_sequence_schema_defs.h b/onnxruntime/core/graph/contrib_ops/reverse_sequence_schema_defs.h new file mode 100644 index 0000000000..7625b941b2 --- /dev/null +++ b/onnxruntime/core/graph/contrib_ops/reverse_sequence_schema_defs.h @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wignored-qualifiers" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif +#include "onnx/defs/schema.h" +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +namespace onnxruntime { +namespace contrib { + +ONNX_NAMESPACE::OpSchema& RegisterReverseSequenceOpSchema(ONNX_NAMESPACE::OpSchema&& op_schema); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/shrink.cc b/onnxruntime/core/providers/cpu/nn/shrink.cc index a2378fae5d..2248cd22ec 100644 --- a/onnxruntime/core/providers/cpu/nn/shrink.cc +++ b/onnxruntime/core/providers/cpu/nn/shrink.cc @@ -64,6 +64,14 @@ Status ShrinkImpl(const Tensor* /*input*/, Tensor* /*output*/, float /*bia "to all numeric types only. Got bool type here."); } +template <> +Status ShrinkImpl(const Tensor* /*input*/, Tensor* /*output*/, float /*bias*/, float /*lambd*/) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Input types for the Shrink operator are constrained " + "to all numeric types only. Got std::string type here."); +} + } // namespace shrink_internal Status Shrink::Compute(OpKernelContext* p_op_kernel_context) const { @@ -76,4 +84,4 @@ Status Shrink::Compute(OpKernelContext* p_op_kernel_context) const { DispatchOnTensorTypeWithReturn(dtype, status, ShrinkImpl, input, output, bias_, lambd_); return status; } -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/reverse_sequence_test.cc b/onnxruntime/test/contrib_ops/reverse_sequence_test.cc new file mode 100644 index 0000000000..3454686961 --- /dev/null +++ b/onnxruntime/test/contrib_ops/reverse_sequence_test.cc @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +TEST(ReverseSequenceTest, BatchMajor) { + OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain); + std::vector input = {0, 1, 2, 3, + 4, 5, 6, 7}; + std::vector sequence_lens = {4, 3}; + std::vector expected_output = {3, 2, 1, 0, + 6, 5, 4, 7}; + + test.AddAttribute("batch_axis", int64_t(0)); + test.AddAttribute("time_axis", int64_t(1)); + + test.AddInput("input", {2, 4, 1}, input); + test.AddInput("sequence_lens", {2}, sequence_lens); + test.AddOutput("Y", {2, 4, 1}, expected_output); + test.Run(); +} + +TEST(ReverseSequenceTest, TimeMajor) { + OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain); + std::vector input = {0, 4, + 1, 5, + 2, 6, + 3, 7}; + + std::vector sequence_lens = {4, 3}; + std::vector expected_output = {3, 6, + 2, 5, + 1, 4, + 0, 7}; + + test.AddAttribute("batch_axis", int64_t(1)); + test.AddAttribute("time_axis", int64_t(0)); + + test.AddInput("input", {4, 2, 1}, input); + test.AddInput("sequence_lens", {2}, sequence_lens); + test.AddOutput("Y", {4, 2, 1}, expected_output); + test.Run(); +} + +TEST(ReverseSequenceTest, LargerDim2) { + OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain); + std::vector input = {0.f, 1.f, + 2.f, 3.f, + 4.f, 5.f, + + 6.f, 7.f, + 8.f, 9.f, + 10.f, 11.f}; + std::vector sequence_lens = {2, 3}; + std::vector expected_output = {2.f, 3.f, + 0.f, 1.f, + 4.f, 5.f, + + 10.f, 11.f, + 8.f, 9.f, + 6.f, 7.f}; + + test.AddAttribute("batch_axis", int64_t(0)); + test.AddAttribute("time_axis", int64_t(1)); + + test.AddInput("input", {2, 3, 2}, input); + test.AddInput("sequence_lens", {2}, sequence_lens); + test.AddOutput("Y", {2, 3, 2}, expected_output); + test.Run(); +} + +TEST(ReverseSequenceTest, Strings) { + OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain); + std::vector input = {"0", "4 string longer than 16 chars that requires its own buffer", + "1", "5", + "2", "6", + "3", "7"}; + + std::vector sequence_lens = {4, 3}; + std::vector expected_output = {"3", "6", + "2", "5", + "1", "4 string longer than 16 chars that requires its own buffer", + "0", "7"}; + + test.AddAttribute("batch_axis", int64_t(1)); + test.AddAttribute("time_axis", int64_t(0)); + + test.AddInput("input", {4, 2, 1}, input); + test.AddInput("sequence_lens", {2}, sequence_lens); + test.AddOutput("Y", {4, 2, 1}, expected_output); + test.Run(); +} + +TEST(ReverseSequenceTest, InvalidInput) { + { + int64_t batch_size = 2, seq_size = 4; + + // Bad axis values + auto check_bad_axis = [&](int64_t batch_dim, int64_t seq_dim, + const std::vector& input_shape, + const std::string err_msg) { + OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain); + std::vector input(batch_size * seq_size, 0); + std::vector sequence_lens(batch_size, 1); + std::vector expected_output = input; + + test.AddAttribute("batch_axis", batch_dim); + test.AddAttribute("time_axis", seq_dim); + + test.AddInput("input", input_shape, input); + test.AddInput("sequence_lens", {batch_size}, sequence_lens); + test.AddOutput("Y", input_shape, expected_output); + test.Run(test::OpTester::ExpectResult::kExpectFailure, err_msg); + }; + + check_bad_axis(2, 1, {1, seq_size, batch_size}, "Invalid batch_axis of 2. Must be 0 or 1"); + check_bad_axis(0, 2, {batch_size, 1, seq_size}, "Invalid time_axis of 2. Must be 0 or 1"); + check_bad_axis(1, 1, {batch_size, seq_size, 1}, "time_axis and batch_axis must have different values but both are 1"); + } + + // invalid sequence_lens size + { + OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain); + + // Bad data_format value + std::vector input = {0, 1, 2, 3, + 4, 5, 6, 7}; + std::vector sequence_lens = {4, 3, 4}; + std::vector expected_output = {3, 2, 1, 0, + 6, 5, 4, 7}; + + test.AddAttribute("batch_axis", int64_t(0)); + test.AddAttribute("time_axis", int64_t(1)); + + test.AddInput("input", {2, 4, 1}, input); + test.AddInput("sequence_lens", {3}, sequence_lens); + test.AddOutput("Y", {2, 4, 1}, expected_output); + test.Run(test::OpTester::ExpectResult::kExpectFailure, + "sequence_lens shape must be {batch_size}. Got:{3}. batch_size=2"); + } +} + +} // namespace test +} // namespace onnxruntime