mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
ReverseSequence contrib op (#728)
This commit is contained in:
parent
333171f602
commit
b9b6e3abcb
9 changed files with 528 additions and 33 deletions
|
|
@ -28,6 +28,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ConvI
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ROIAlign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, ROIAlign);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConv);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReverseSequence);
|
||||
|
||||
void RegisterContribKernels(KernelRegistry& kernel_registry) {
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp)>());
|
||||
|
|
@ -54,6 +55,7 @@ void RegisterContribKernels(KernelRegistry& kernel_registry) {
|
|||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ROIAlign)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, ROIAlign)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConv)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReverseSequence)>());
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
|
|
|
|||
160
onnxruntime/contrib_ops/cpu/reverse_sequence.cc
Normal file
160
onnxruntime/contrib_ops/cpu/reverse_sequence.cc
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "reverse_sequence.h"
|
||||
#include "onnx/defs/schema.h"
|
||||
|
||||
// there's no way to use a raw pointer as the copy destination with std::copy_n
|
||||
// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset
|
||||
// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning.
|
||||
#ifdef _MSC_VER
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4996)
|
||||
#endif
|
||||
|
||||
#include "gsl/gsl_algorithm"
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#pragma warning(pop)
|
||||
#endif
|
||||
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/framework/tensor_shape.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(ReverseSequence,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>()),
|
||||
ReverseSequenceOp);
|
||||
|
||||
template <typename T>
|
||||
static void ReverseSequenceImpl(const Tensor& X, Tensor& Y,
|
||||
gsl::span<const int> sequence_lengths,
|
||||
const int64_t max_seq_len,
|
||||
const int64_t batch_size,
|
||||
const int64_t input_size,
|
||||
bool time_major);
|
||||
|
||||
Status ReverseSequenceOp::Compute(OpKernelContext* context) const {
|
||||
Status status = Status::OK();
|
||||
|
||||
const auto& X = *context->Input<Tensor>(0);
|
||||
const auto data_type = X.DataType();
|
||||
const auto& dims = X.Shape();
|
||||
|
||||
const auto batch_size = time_major_ ? dims[1] : dims[0];
|
||||
const auto max_seq_len = time_major_ ? dims[0] : dims[1];
|
||||
const auto input_size = dims.SizeFromDimension(2);
|
||||
|
||||
const auto& seq_lengths = *context->Input<Tensor>(1);
|
||||
const auto& seq_len_shape = seq_lengths.Shape();
|
||||
|
||||
if (seq_len_shape.NumDimensions() != 1 || seq_len_shape[0] != batch_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "sequence_lens shape must be {batch_size}. Got:",
|
||||
seq_len_shape, ". batch_size=", batch_size);
|
||||
}
|
||||
|
||||
auto& Y = *context->Output(0, dims);
|
||||
|
||||
DispatchOnTensorType(data_type, ReverseSequenceImpl, X, Y, seq_lengths.DataAsSpan<int>(),
|
||||
max_seq_len, batch_size, input_size, time_major_);
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
static int64_t TimeMajorInputOffset(const int64_t max_seq_len,
|
||||
const int64_t batch_size,
|
||||
const int64_t input_size,
|
||||
const int64_t batch_num,
|
||||
const int64_t seq_num) {
|
||||
ORT_UNUSED_PARAMETER(max_seq_len);
|
||||
return seq_num * batch_size * input_size + batch_num * input_size;
|
||||
}
|
||||
|
||||
static int64_t BatchMajorInputOffset(const int64_t max_seq_len,
|
||||
const int64_t batch_size,
|
||||
const int64_t input_size,
|
||||
const int64_t batch_num,
|
||||
const int64_t seq_num) {
|
||||
ORT_UNUSED_PARAMETER(batch_size);
|
||||
return batch_num * max_seq_len * input_size + seq_num * input_size;
|
||||
}
|
||||
|
||||
static int64_t TimeMajorOutputOffset(const int64_t max_seq_len,
|
||||
const int64_t batch_size,
|
||||
const int64_t input_size,
|
||||
const int64_t batch_num,
|
||||
const int64_t seq_num,
|
||||
const int64_t seq_len) {
|
||||
ORT_UNUSED_PARAMETER(max_seq_len);
|
||||
return (seq_len - seq_num - 1) * batch_size * input_size + batch_num * input_size;
|
||||
}
|
||||
|
||||
static int64_t BatchMajorOutputOffset(const int64_t max_seq_len,
|
||||
const int64_t batch_size,
|
||||
const int64_t input_size,
|
||||
const int64_t batch_num,
|
||||
const int64_t seq_num,
|
||||
const int64_t seq_len) {
|
||||
ORT_UNUSED_PARAMETER(batch_size);
|
||||
return batch_num * max_seq_len * input_size + (seq_len - seq_num - 1) * input_size;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void ReverseSequenceImpl(const Tensor& X,
|
||||
Tensor& Y,
|
||||
gsl::span<const int> sequence_lengths,
|
||||
const int64_t max_seq_len,
|
||||
const int64_t batch_size,
|
||||
const int64_t input_size,
|
||||
bool time_major) {
|
||||
gsl::span<const T> inputs = X.DataAsSpan<T>();
|
||||
gsl::span<T> inputs_reverse = Y.MutableDataAsSpan<T>();
|
||||
|
||||
auto input_offset = time_major ? TimeMajorInputOffset : BatchMajorInputOffset;
|
||||
|
||||
auto reversed_output_offset = time_major ? TimeMajorOutputOffset : BatchMajorOutputOffset;
|
||||
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
int seq_len = sequence_lengths[i];
|
||||
|
||||
if (seq_len == 0)
|
||||
continue;
|
||||
|
||||
#ifdef USE_OPENMP
|
||||
// Parallel execute the loop.
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int j = 0; j < seq_len; j++) {
|
||||
gsl::span<const T> src = inputs.subspan(input_offset(max_seq_len, batch_size, input_size, i, j), input_size);
|
||||
gsl::span<T> dest = inputs_reverse.subspan(
|
||||
reversed_output_offset(max_seq_len, batch_size, input_size, i, j, seq_len), input_size);
|
||||
|
||||
// Use gsl::copy instead of std::copy() to allow compiler to optimize the code
|
||||
gsl::copy(src, dest);
|
||||
}
|
||||
|
||||
#ifdef USE_OPENMP
|
||||
// Parallel execute the loop.
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int j = seq_len; j < max_seq_len; j++) {
|
||||
const auto offset = input_offset(max_seq_len, batch_size, input_size, i, j);
|
||||
gsl::span<const T> src = inputs.subspan(offset, input_size);
|
||||
gsl::span<T> dest = inputs_reverse.subspan(offset, input_size);
|
||||
|
||||
// Use gsl::copy instead of std::copy() to allow compiler to optimize the code
|
||||
gsl::copy(src, dest);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
36
onnxruntime/contrib_ops/cpu/reverse_sequence.h
Normal file
36
onnxruntime/contrib_ops/cpu/reverse_sequence.h
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/tensor.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
class ReverseSequenceOp : public OpKernel {
|
||||
public:
|
||||
explicit ReverseSequenceOp(const OpKernelInfo& info) : OpKernel(info) {
|
||||
int64_t batch_axis, time_axis;
|
||||
ORT_ENFORCE(info.GetAttr<int64_t>("batch_axis", &batch_axis).IsOK());
|
||||
ORT_ENFORCE(info.GetAttr<int64_t>("time_axis", &time_axis).IsOK());
|
||||
|
||||
ORT_ENFORCE(batch_axis < 2, "Invalid batch_axis of ", batch_axis, ". Must be 0 or 1");
|
||||
ORT_ENFORCE(time_axis < 2, "Invalid time_axis of ", time_axis, ". Must be 0 or 1");
|
||||
|
||||
ORT_ENFORCE(batch_axis != time_axis,
|
||||
"time_axis and batch_axis must have different values but both are ", time_axis);
|
||||
|
||||
time_major_ = time_axis == 0;
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
bool time_major_;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -65,33 +65,37 @@ common::Status ExecuteGraphWithCachedInfo(const SessionState& session_state,
|
|||
const bool& terminate_flag,
|
||||
const logging::Logger& logger);
|
||||
|
||||
#define DispatchOnTensorType(tensor_type, function, ...) \
|
||||
if (tensor_type == DataTypeImpl::GetType<float>()) \
|
||||
function<float>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<double>()) \
|
||||
function<double>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<int8_t>()) \
|
||||
function<int8_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<int16_t>()) \
|
||||
function<int16_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<int32_t>()) \
|
||||
function<int32_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<int64_t>()) \
|
||||
function<int64_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<uint8_t>()) \
|
||||
function<uint8_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<uint16_t>()) \
|
||||
function<uint16_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<uint32_t>()) \
|
||||
function<uint32_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<uint64_t>()) \
|
||||
function<uint64_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<bool>()) \
|
||||
function<bool>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<MLFloat16>()) \
|
||||
function<MLFloat16>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<BFloat16>()) \
|
||||
function<BFloat16>(__VA_ARGS__)
|
||||
#define DispatchOnTensorType(tensor_type, function, ...) \
|
||||
if (tensor_type == DataTypeImpl::GetType<float>()) \
|
||||
function<float>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<double>()) \
|
||||
function<double>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<int8_t>()) \
|
||||
function<int8_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<int16_t>()) \
|
||||
function<int16_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<int32_t>()) \
|
||||
function<int32_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<int64_t>()) \
|
||||
function<int64_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<uint8_t>()) \
|
||||
function<uint8_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<uint16_t>()) \
|
||||
function<uint16_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<uint32_t>()) \
|
||||
function<uint32_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<uint64_t>()) \
|
||||
function<uint64_t>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<bool>()) \
|
||||
function<bool>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<MLFloat16>()) \
|
||||
function<MLFloat16>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<BFloat16>()) \
|
||||
function<BFloat16>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<std::string>()) \
|
||||
function<std::string>(__VA_ARGS__); \
|
||||
else \
|
||||
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type)
|
||||
|
||||
#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
|
||||
if (tensor_type == DataTypeImpl::GetType<float>()) \
|
||||
|
|
@ -119,7 +123,11 @@ common::Status ExecuteGraphWithCachedInfo(const SessionState& session_state,
|
|||
else if (tensor_type == DataTypeImpl::GetType<MLFloat16>()) \
|
||||
retval = function<MLFloat16>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<BFloat16>()) \
|
||||
retval = function<BFloat16>(__VA_ARGS__)
|
||||
retval = function<BFloat16>(__VA_ARGS__); \
|
||||
else if (tensor_type == DataTypeImpl::GetType<std::string>()) \
|
||||
retval = function<std::string>(__VA_ARGS__); \
|
||||
else \
|
||||
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type)
|
||||
|
||||
} // namespace utils
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#include "core/graph/contrib_ops/attn_lstm_schema_defs.h"
|
||||
#include "core/graph/contrib_ops/contrib_defs.h"
|
||||
#include "core/graph/contrib_ops/range_schema_defs.h"
|
||||
#include "core/graph/contrib_ops/reverse_sequence_schema_defs.h"
|
||||
#include "core/graph/op.h"
|
||||
#include "onnx/defs/schema.h"
|
||||
#include "onnx/defs/shape_inference.h"
|
||||
|
|
@ -18,9 +19,9 @@ void convPoolTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, bool u
|
|||
}
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
using ::ONNX_NAMESPACE::AttributeProto;
|
||||
using ::ONNX_NAMESPACE::OPTIONAL;
|
||||
using ::ONNX_NAMESPACE::OpSchema;
|
||||
using ONNX_NAMESPACE::AttributeProto;
|
||||
using ONNX_NAMESPACE::OpSchema;
|
||||
using ONNX_NAMESPACE::OPTIONAL;
|
||||
|
||||
void matmulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx) {
|
||||
if (!hasInputShape(ctx, input1Idx) && !hasInputShape(ctx, input2Idx)) {
|
||||
|
|
@ -221,7 +222,6 @@ void convPoolShapeInference(
|
|||
}
|
||||
|
||||
void RegisterContribSchemas() {
|
||||
|
||||
// ONNX exp ops(Affine, Crop, ParametricSoftplus, ImageScaler) old version history maintainance
|
||||
static const char* Affine_ver1_doc = R"DOC(
|
||||
Affine takes one input data (Tensor<T>) and produces one output data
|
||||
|
|
@ -585,6 +585,7 @@ activation and leaky_relu_alpha.)DOC")
|
|||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(AttnLSTM, RegisterAttnLSTMContribOpSchema);
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(Range, RegisterRangeOpSchema);
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(ReverseSequence, RegisterReverseSequenceOpSchema);
|
||||
|
||||
static const char* Tokenizer_ver1_doc = R"DOC(
|
||||
Tokenizer divides each string in X into a vector of strings along the last axis. Allowed input shapes are [C] and [N, C].
|
||||
|
|
|
|||
|
|
@ -0,0 +1,110 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "reverse_sequence_schema_defs.h"
|
||||
#include "core/graph/op.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
using ONNX_NAMESPACE::AttributeProto;
|
||||
using ONNX_NAMESPACE::InferenceContext;
|
||||
using ONNX_NAMESPACE::OpSchema;
|
||||
using ONNX_NAMESPACE::OPTIONAL;
|
||||
using ONNX_NAMESPACE::TensorProto;
|
||||
using ONNX_NAMESPACE::TensorProto_DataType;
|
||||
using ONNX_NAMESPACE::TensorShapeProto;
|
||||
|
||||
static const char* ReverseSequence_ver1_doc = R"DOC(
|
||||
Reverse batch of sequences having different lengths specified by `sequence_lens`.
|
||||
|
||||
For each slice i iterating on batch axis, the operator reverses the first sequence_lens[i] elements on time axis,
|
||||
and copies elements whose index's beyond sequence_lens[i] to the output. So the output slice i contains reversed
|
||||
sequences on the first sequence_lens[i] elements, then have original values copied for the other elements.
|
||||
|
||||
Example 1:
|
||||
input = [[0.0, 4.0, 8.0, 12.0],
|
||||
[1.0, 5.0, 9.0, 13.0],
|
||||
[2.0, 6.0, 10.0, 14.0],
|
||||
[3.0, 7.0, 11.0, 15.0]]
|
||||
sequence_lens = [4, 3, 2, 1]
|
||||
time_axis = 0
|
||||
batch_axis = 1
|
||||
|
||||
output = [[3.0, 6.0, 9.0, 12.0],
|
||||
[2.0, 5.0, 8.0, 13.0],
|
||||
[1.0, 4.0, 10.0, 14.0],
|
||||
[0.0, 7.0, 11.0, 15.0]]
|
||||
|
||||
Example 2:
|
||||
input = [[0.0, 1.0, 2.0, 3.0 ],
|
||||
[4.0, 5.0, 6.0, 7.0 ],
|
||||
[8.0, 9.0, 10.0, 11.0],
|
||||
[12.0, 13.0, 14.0, 15.0]]
|
||||
sequence_lens = [1, 2, 3, 4]
|
||||
time_axis = 1
|
||||
batch_axis = 0
|
||||
|
||||
output = [[0.0, 1.0, 2.0, 3.0 ],
|
||||
[5.0, 4.0, 6.0, 7.0 ],
|
||||
[10.0, 9.0, 8.0, 11.0],
|
||||
[15.0, 14.0, 13.0, 12.0]]
|
||||
)DOC";
|
||||
|
||||
void ReverseSequenceShapeInference(InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
if (!hasNInputShapes(ctx, 2)) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& first_input_shape = getInputShape(ctx, 0);
|
||||
if (first_input_shape.dim_size() < 2) {
|
||||
fail_shape_inference("'input' must have rank >= 2");
|
||||
}
|
||||
auto& seq_len_input_shape = getInputShape(ctx, 1);
|
||||
if (seq_len_input_shape.dim_size() != 1) {
|
||||
fail_shape_inference("'sequence_lens' must have rank of 1");
|
||||
}
|
||||
|
||||
propagateShapeFromInputToOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
OpSchema& RegisterReverseSequenceOpSchema(OpSchema&& op_schema) {
|
||||
return op_schema
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
OpSchema::all_tensor_types(),
|
||||
"Input and output types can be of any tensor type.")
|
||||
.Attr(
|
||||
"time_axis",
|
||||
"(Optional) Specify which axis is time axis. Must be one of 0 (default), or 1.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(0))
|
||||
.Attr(
|
||||
"batch_axis",
|
||||
"(Optional) Specify which axis is batch axis. Must be one of 1 (default), or 0.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(1))
|
||||
.Input(
|
||||
0,
|
||||
"input",
|
||||
"Tensor of rank r >= 2.",
|
||||
"T")
|
||||
.Input(
|
||||
1,
|
||||
"sequence_lens",
|
||||
"Tensor specifying lengths of the sequences in a batch. It has shape `[batch_size]`.",
|
||||
"tensor(int32)")
|
||||
.Output(
|
||||
0,
|
||||
"Y",
|
||||
"Tensor with same shape of input.",
|
||||
"T")
|
||||
.SetDoc(ReverseSequence_ver1_doc)
|
||||
.TypeAndShapeInferenceFunction(ReverseSequenceShapeInference);
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
|
||||
#pragma GCC diagnostic ignored "-Wunused-parameter"
|
||||
#endif
|
||||
#include "onnx/defs/schema.h"
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
ONNX_NAMESPACE::OpSchema& RegisterReverseSequenceOpSchema(ONNX_NAMESPACE::OpSchema&& op_schema);
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -64,6 +64,14 @@ Status ShrinkImpl<bool>(const Tensor* /*input*/, Tensor* /*output*/, float /*bia
|
|||
"to all numeric types only. Got bool type here.");
|
||||
}
|
||||
|
||||
template <>
|
||||
Status ShrinkImpl<std::string>(const Tensor* /*input*/, Tensor* /*output*/, float /*bias*/, float /*lambd*/) {
|
||||
return ORT_MAKE_STATUS(
|
||||
ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input types for the Shrink operator are constrained "
|
||||
"to all numeric types only. Got std::string type here.");
|
||||
}
|
||||
|
||||
} // namespace shrink_internal
|
||||
|
||||
Status Shrink::Compute(OpKernelContext* p_op_kernel_context) const {
|
||||
|
|
@ -76,4 +84,4 @@ Status Shrink::Compute(OpKernelContext* p_op_kernel_context) const {
|
|||
DispatchOnTensorTypeWithReturn(dtype, status, ShrinkImpl, input, output, bias_, lambd_);
|
||||
return status;
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
148
onnxruntime/test/contrib_ops/reverse_sequence_test.cc
Normal file
148
onnxruntime/test/contrib_ops/reverse_sequence_test.cc
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
TEST(ReverseSequenceTest, BatchMajor) {
|
||||
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
|
||||
std::vector<int64_t> input = {0, 1, 2, 3,
|
||||
4, 5, 6, 7};
|
||||
std::vector<int32_t> sequence_lens = {4, 3};
|
||||
std::vector<int64_t> expected_output = {3, 2, 1, 0,
|
||||
6, 5, 4, 7};
|
||||
|
||||
test.AddAttribute("batch_axis", int64_t(0));
|
||||
test.AddAttribute("time_axis", int64_t(1));
|
||||
|
||||
test.AddInput<int64_t>("input", {2, 4, 1}, input);
|
||||
test.AddInput<int32_t>("sequence_lens", {2}, sequence_lens);
|
||||
test.AddOutput<int64_t>("Y", {2, 4, 1}, expected_output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ReverseSequenceTest, TimeMajor) {
|
||||
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
|
||||
std::vector<int64_t> input = {0, 4,
|
||||
1, 5,
|
||||
2, 6,
|
||||
3, 7};
|
||||
|
||||
std::vector<int32_t> sequence_lens = {4, 3};
|
||||
std::vector<int64_t> expected_output = {3, 6,
|
||||
2, 5,
|
||||
1, 4,
|
||||
0, 7};
|
||||
|
||||
test.AddAttribute("batch_axis", int64_t(1));
|
||||
test.AddAttribute("time_axis", int64_t(0));
|
||||
|
||||
test.AddInput<int64_t>("input", {4, 2, 1}, input);
|
||||
test.AddInput<int32_t>("sequence_lens", {2}, sequence_lens);
|
||||
test.AddOutput<int64_t>("Y", {4, 2, 1}, expected_output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ReverseSequenceTest, LargerDim2) {
|
||||
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
|
||||
std::vector<float> input = {0.f, 1.f,
|
||||
2.f, 3.f,
|
||||
4.f, 5.f,
|
||||
|
||||
6.f, 7.f,
|
||||
8.f, 9.f,
|
||||
10.f, 11.f};
|
||||
std::vector<int32_t> sequence_lens = {2, 3};
|
||||
std::vector<float> expected_output = {2.f, 3.f,
|
||||
0.f, 1.f,
|
||||
4.f, 5.f,
|
||||
|
||||
10.f, 11.f,
|
||||
8.f, 9.f,
|
||||
6.f, 7.f};
|
||||
|
||||
test.AddAttribute("batch_axis", int64_t(0));
|
||||
test.AddAttribute("time_axis", int64_t(1));
|
||||
|
||||
test.AddInput<float>("input", {2, 3, 2}, input);
|
||||
test.AddInput<int32_t>("sequence_lens", {2}, sequence_lens);
|
||||
test.AddOutput<float>("Y", {2, 3, 2}, expected_output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ReverseSequenceTest, Strings) {
|
||||
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
|
||||
std::vector<std::string> input = {"0", "4 string longer than 16 chars that requires its own buffer",
|
||||
"1", "5",
|
||||
"2", "6",
|
||||
"3", "7"};
|
||||
|
||||
std::vector<int32_t> sequence_lens = {4, 3};
|
||||
std::vector<std::string> expected_output = {"3", "6",
|
||||
"2", "5",
|
||||
"1", "4 string longer than 16 chars that requires its own buffer",
|
||||
"0", "7"};
|
||||
|
||||
test.AddAttribute("batch_axis", int64_t(1));
|
||||
test.AddAttribute("time_axis", int64_t(0));
|
||||
|
||||
test.AddInput<std::string>("input", {4, 2, 1}, input);
|
||||
test.AddInput<int32_t>("sequence_lens", {2}, sequence_lens);
|
||||
test.AddOutput<std::string>("Y", {4, 2, 1}, expected_output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ReverseSequenceTest, InvalidInput) {
|
||||
{
|
||||
int64_t batch_size = 2, seq_size = 4;
|
||||
|
||||
// Bad axis values
|
||||
auto check_bad_axis = [&](int64_t batch_dim, int64_t seq_dim,
|
||||
const std::vector<int64_t>& input_shape,
|
||||
const std::string err_msg) {
|
||||
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
|
||||
std::vector<int64_t> input(batch_size * seq_size, 0);
|
||||
std::vector<int32_t> sequence_lens(batch_size, 1);
|
||||
std::vector<int64_t> expected_output = input;
|
||||
|
||||
test.AddAttribute("batch_axis", batch_dim);
|
||||
test.AddAttribute("time_axis", seq_dim);
|
||||
|
||||
test.AddInput<int64_t>("input", input_shape, input);
|
||||
test.AddInput<int32_t>("sequence_lens", {batch_size}, sequence_lens);
|
||||
test.AddOutput<int64_t>("Y", input_shape, expected_output);
|
||||
test.Run(test::OpTester::ExpectResult::kExpectFailure, err_msg);
|
||||
};
|
||||
|
||||
check_bad_axis(2, 1, {1, seq_size, batch_size}, "Invalid batch_axis of 2. Must be 0 or 1");
|
||||
check_bad_axis(0, 2, {batch_size, 1, seq_size}, "Invalid time_axis of 2. Must be 0 or 1");
|
||||
check_bad_axis(1, 1, {batch_size, seq_size, 1}, "time_axis and batch_axis must have different values but both are 1");
|
||||
}
|
||||
|
||||
// invalid sequence_lens size
|
||||
{
|
||||
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
|
||||
|
||||
// Bad data_format value
|
||||
std::vector<int64_t> input = {0, 1, 2, 3,
|
||||
4, 5, 6, 7};
|
||||
std::vector<int32_t> sequence_lens = {4, 3, 4};
|
||||
std::vector<int64_t> expected_output = {3, 2, 1, 0,
|
||||
6, 5, 4, 7};
|
||||
|
||||
test.AddAttribute("batch_axis", int64_t(0));
|
||||
test.AddAttribute("time_axis", int64_t(1));
|
||||
|
||||
test.AddInput<int64_t>("input", {2, 4, 1}, input);
|
||||
test.AddInput<int32_t>("sequence_lens", {3}, sequence_lens);
|
||||
test.AddOutput<int64_t>("Y", {2, 4, 1}, expected_output);
|
||||
test.Run(test::OpTester::ExpectResult::kExpectFailure,
|
||||
"sequence_lens shape must be {batch_size}. Got:{3}. batch_size=2");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue