mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Migrate ReverseSequence from contrib op to ONNX opset 10 (#896)
* Update ONNX to 70c9026ca11b0af0050f8186bea6cab94636947f to pickup ReverseSequence op. Copy ReverseSequence from contrib ops to ONNX (keep contrib op in this commit), and update to use int64_t for sequence_lens input. * Copy ReverseSequence from contrib to ONNX and update to use int64_t for sequence_lens. Maintain contrib op in this commit. * Remove contrib op as it was temporary and only used internally. * Remove contrib op schema defs. * Cleanup contrib_defs.cc
This commit is contained in:
parent
8ed3eed7b5
commit
b8eaa88bd4
8 changed files with 29 additions and 169 deletions
|
|
@ -18,7 +18,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordC
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MaxpoolWithMask);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReverseSequence);
|
||||
|
||||
// This section includes all opkernel declarations for former experimental ops which have now been removed from onnx.
|
||||
// To maintain backward compatibility these are added as contrib ops.
|
||||
|
|
@ -61,7 +60,6 @@ void RegisterContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MaxpoolWithMask)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReverseSequence)>,
|
||||
|
||||
// These ops were experimental ops in onnx domain which have been removed now. We add them here as
|
||||
// contrib ops to main backward compatibility
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
#include "core/graph/contrib_ops/attn_lstm_schema_defs.h"
|
||||
#include "core/graph/contrib_ops/contrib_defs.h"
|
||||
#include "core/graph/contrib_ops/range_schema_defs.h"
|
||||
#include "core/graph/contrib_ops/reverse_sequence_schema_defs.h"
|
||||
#include "core/graph/op.h"
|
||||
#include "onnx/defs/schema.h"
|
||||
#include "onnx/defs/shape_inference.h"
|
||||
|
|
@ -729,7 +728,6 @@ activation and leaky_relu_alpha.)DOC")
|
|||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(AttnLSTM, RegisterAttnLSTMContribOpSchema);
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(Range, RegisterRangeOpSchema);
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(ReverseSequence, RegisterReverseSequenceOpSchema);
|
||||
|
||||
static const char* Tokenizer_ver1_doc = R"DOC(
|
||||
Tokenizer divides each string in X into a vector of strings along the last axis. Allowed input shapes are [C] and [N, C].
|
||||
|
|
|
|||
|
|
@ -1,110 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "reverse_sequence_schema_defs.h"
|
||||
#include "core/graph/op.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
using ONNX_NAMESPACE::AttributeProto;
|
||||
using ONNX_NAMESPACE::InferenceContext;
|
||||
using ONNX_NAMESPACE::OpSchema;
|
||||
using ONNX_NAMESPACE::OPTIONAL;
|
||||
using ONNX_NAMESPACE::TensorProto;
|
||||
using ONNX_NAMESPACE::TensorProto_DataType;
|
||||
using ONNX_NAMESPACE::TensorShapeProto;
|
||||
|
||||
static const char* ReverseSequence_ver1_doc = R"DOC(
|
||||
Reverse batch of sequences having different lengths specified by `sequence_lens`.
|
||||
|
||||
For each slice i iterating on batch axis, the operator reverses the first sequence_lens[i] elements on time axis,
|
||||
and copies elements whose index's beyond sequence_lens[i] to the output. So the output slice i contains reversed
|
||||
sequences on the first sequence_lens[i] elements, then have original values copied for the other elements.
|
||||
|
||||
Example 1:
|
||||
input = [[0.0, 4.0, 8.0, 12.0],
|
||||
[1.0, 5.0, 9.0, 13.0],
|
||||
[2.0, 6.0, 10.0, 14.0],
|
||||
[3.0, 7.0, 11.0, 15.0]]
|
||||
sequence_lens = [4, 3, 2, 1]
|
||||
time_axis = 0
|
||||
batch_axis = 1
|
||||
|
||||
output = [[3.0, 6.0, 9.0, 12.0],
|
||||
[2.0, 5.0, 8.0, 13.0],
|
||||
[1.0, 4.0, 10.0, 14.0],
|
||||
[0.0, 7.0, 11.0, 15.0]]
|
||||
|
||||
Example 2:
|
||||
input = [[0.0, 1.0, 2.0, 3.0 ],
|
||||
[4.0, 5.0, 6.0, 7.0 ],
|
||||
[8.0, 9.0, 10.0, 11.0],
|
||||
[12.0, 13.0, 14.0, 15.0]]
|
||||
sequence_lens = [1, 2, 3, 4]
|
||||
time_axis = 1
|
||||
batch_axis = 0
|
||||
|
||||
output = [[0.0, 1.0, 2.0, 3.0 ],
|
||||
[5.0, 4.0, 6.0, 7.0 ],
|
||||
[10.0, 9.0, 8.0, 11.0],
|
||||
[15.0, 14.0, 13.0, 12.0]]
|
||||
)DOC";
|
||||
|
||||
void ReverseSequenceShapeInference(InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
if (!hasNInputShapes(ctx, 2)) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& first_input_shape = getInputShape(ctx, 0);
|
||||
if (first_input_shape.dim_size() < 2) {
|
||||
fail_shape_inference("'input' must have rank >= 2");
|
||||
}
|
||||
auto& seq_len_input_shape = getInputShape(ctx, 1);
|
||||
if (seq_len_input_shape.dim_size() != 1) {
|
||||
fail_shape_inference("'sequence_lens' must have rank of 1");
|
||||
}
|
||||
|
||||
propagateShapeFromInputToOutput(ctx, 0, 0);
|
||||
}
|
||||
|
||||
OpSchema& RegisterReverseSequenceOpSchema(OpSchema&& op_schema) {
|
||||
return op_schema
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
OpSchema::all_tensor_types(),
|
||||
"Input and output types can be of any tensor type.")
|
||||
.Attr(
|
||||
"time_axis",
|
||||
"(Optional) Specify which axis is time axis. Must be one of 0 (default), or 1.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(0))
|
||||
.Attr(
|
||||
"batch_axis",
|
||||
"(Optional) Specify which axis is batch axis. Must be one of 1 (default), or 0.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(1))
|
||||
.Input(
|
||||
0,
|
||||
"input",
|
||||
"Tensor of rank r >= 2.",
|
||||
"T")
|
||||
.Input(
|
||||
1,
|
||||
"sequence_lens",
|
||||
"Tensor specifying lengths of the sequences in a batch. It has shape `[batch_size]`.",
|
||||
"tensor(int32)")
|
||||
.Output(
|
||||
0,
|
||||
"Y",
|
||||
"Tensor with same shape of input.",
|
||||
"T")
|
||||
.SetDoc(ReverseSequence_ver1_doc)
|
||||
.TypeAndShapeInferenceFunction(ReverseSequenceShapeInference);
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
|
||||
#pragma GCC diagnostic ignored "-Wunused-parameter"
|
||||
#endif
|
||||
#include "onnx/defs/schema.h"
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
ONNX_NAMESPACE::OpSchema& RegisterReverseSequenceOpSchema(ONNX_NAMESPACE::OpSchema&& op_schema);
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -270,6 +270,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, No
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, float, RoiAlign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, double, RoiAlign);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ReverseSequence);
|
||||
|
||||
void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
|
|
@ -529,6 +530,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, float, RoiAlign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, double, RoiAlign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ReverseSequence)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -23,20 +23,17 @@
|
|||
#include "core/framework/tensor_shape.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(ReverseSequence,
|
||||
kMSDomain,
|
||||
1,
|
||||
kOnnxDomain,
|
||||
10,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>()),
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
|
||||
ReverseSequenceOp);
|
||||
|
||||
template <typename T>
|
||||
static void ReverseSequenceImpl(const Tensor& X, Tensor& Y,
|
||||
gsl::span<const int> sequence_lengths,
|
||||
gsl::span<const int64_t> sequence_lengths,
|
||||
const int64_t max_seq_len,
|
||||
const int64_t batch_size,
|
||||
const int64_t input_size,
|
||||
|
|
@ -63,7 +60,7 @@ Status ReverseSequenceOp::Compute(OpKernelContext* context) const {
|
|||
|
||||
auto& Y = *context->Output(0, dims);
|
||||
|
||||
DispatchOnTensorType(data_type, ReverseSequenceImpl, X, Y, seq_lengths.DataAsSpan<int>(),
|
||||
DispatchOnTensorType(data_type, ReverseSequenceImpl, X, Y, seq_lengths.DataAsSpan<int64_t>(),
|
||||
max_seq_len, batch_size, input_size, time_major_);
|
||||
|
||||
return status;
|
||||
|
|
@ -110,7 +107,7 @@ static int64_t BatchMajorOutputOffset(const int64_t max_seq_len,
|
|||
template <typename T>
|
||||
static void ReverseSequenceImpl(const Tensor& X,
|
||||
Tensor& Y,
|
||||
gsl::span<const int> sequence_lengths,
|
||||
gsl::span<const int64_t> sequence_lengths,
|
||||
const int64_t max_seq_len,
|
||||
const int64_t batch_size,
|
||||
const int64_t input_size,
|
||||
|
|
@ -123,7 +120,7 @@ static void ReverseSequenceImpl(const Tensor& X,
|
|||
auto reversed_output_offset = time_major ? TimeMajorOutputOffset : BatchMajorOutputOffset;
|
||||
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
int seq_len = sequence_lengths[i];
|
||||
int64_t seq_len = sequence_lengths[i];
|
||||
|
||||
if (seq_len == 0)
|
||||
continue;
|
||||
|
|
@ -132,7 +129,7 @@ static void ReverseSequenceImpl(const Tensor& X,
|
|||
// Parallel execute the loop.
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int j = 0; j < seq_len; j++) {
|
||||
for (int64_t j = 0; j < seq_len; j++) {
|
||||
gsl::span<const T> src = inputs.subspan(input_offset(max_seq_len, batch_size, input_size, i, j), input_size);
|
||||
gsl::span<T> dest = inputs_reverse.subspan(
|
||||
reversed_output_offset(max_seq_len, batch_size, input_size, i, j, seq_len), input_size);
|
||||
|
|
@ -145,7 +142,7 @@ static void ReverseSequenceImpl(const Tensor& X,
|
|||
// Parallel execute the loop.
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int j = seq_len; j < max_seq_len; j++) {
|
||||
for (int64_t j = seq_len; j < max_seq_len; j++) {
|
||||
const auto offset = input_offset(max_seq_len, batch_size, input_size, i, j);
|
||||
gsl::span<const T> src = inputs.subspan(offset, input_size);
|
||||
gsl::span<T> dest = inputs_reverse.subspan(offset, input_size);
|
||||
|
|
@ -156,5 +153,4 @@ static void ReverseSequenceImpl(const Tensor& X,
|
|||
}
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -8,7 +8,6 @@
|
|||
#include "core/framework/tensor.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
class ReverseSequenceOp : public OpKernel {
|
||||
public:
|
||||
|
|
@ -32,5 +31,4 @@ class ReverseSequenceOp : public OpKernel {
|
|||
bool time_major_;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -8,10 +8,10 @@ namespace onnxruntime {
|
|||
namespace test {
|
||||
|
||||
TEST(ReverseSequenceTest, BatchMajor) {
|
||||
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
|
||||
OpTester test("ReverseSequence", 10);
|
||||
std::vector<int64_t> input = {0, 1, 2, 3,
|
||||
4, 5, 6, 7};
|
||||
std::vector<int32_t> sequence_lens = {4, 3};
|
||||
std::vector<int64_t> sequence_lens = {4, 3};
|
||||
std::vector<int64_t> expected_output = {3, 2, 1, 0,
|
||||
6, 5, 4, 7};
|
||||
|
||||
|
|
@ -19,19 +19,19 @@ TEST(ReverseSequenceTest, BatchMajor) {
|
|||
test.AddAttribute("time_axis", int64_t(1));
|
||||
|
||||
test.AddInput<int64_t>("input", {2, 4, 1}, input);
|
||||
test.AddInput<int32_t>("sequence_lens", {2}, sequence_lens);
|
||||
test.AddInput<int64_t>("sequence_lens", {2}, sequence_lens);
|
||||
test.AddOutput<int64_t>("Y", {2, 4, 1}, expected_output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ReverseSequenceTest, TimeMajor) {
|
||||
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
|
||||
OpTester test("ReverseSequence", 10);
|
||||
std::vector<int64_t> input = {0, 4,
|
||||
1, 5,
|
||||
2, 6,
|
||||
3, 7};
|
||||
|
||||
std::vector<int32_t> sequence_lens = {4, 3};
|
||||
std::vector<int64_t> sequence_lens = {4, 3};
|
||||
std::vector<int64_t> expected_output = {3, 6,
|
||||
2, 5,
|
||||
1, 4,
|
||||
|
|
@ -41,13 +41,13 @@ TEST(ReverseSequenceTest, TimeMajor) {
|
|||
test.AddAttribute("time_axis", int64_t(0));
|
||||
|
||||
test.AddInput<int64_t>("input", {4, 2, 1}, input);
|
||||
test.AddInput<int32_t>("sequence_lens", {2}, sequence_lens);
|
||||
test.AddInput<int64_t>("sequence_lens", {2}, sequence_lens);
|
||||
test.AddOutput<int64_t>("Y", {4, 2, 1}, expected_output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ReverseSequenceTest, LargerDim2) {
|
||||
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
|
||||
OpTester test("ReverseSequence", 10);
|
||||
std::vector<float> input = {0.f, 1.f,
|
||||
2.f, 3.f,
|
||||
4.f, 5.f,
|
||||
|
|
@ -55,7 +55,7 @@ TEST(ReverseSequenceTest, LargerDim2) {
|
|||
6.f, 7.f,
|
||||
8.f, 9.f,
|
||||
10.f, 11.f};
|
||||
std::vector<int32_t> sequence_lens = {2, 3};
|
||||
std::vector<int64_t> sequence_lens = {2, 3};
|
||||
std::vector<float> expected_output = {2.f, 3.f,
|
||||
0.f, 1.f,
|
||||
4.f, 5.f,
|
||||
|
|
@ -68,19 +68,19 @@ TEST(ReverseSequenceTest, LargerDim2) {
|
|||
test.AddAttribute("time_axis", int64_t(1));
|
||||
|
||||
test.AddInput<float>("input", {2, 3, 2}, input);
|
||||
test.AddInput<int32_t>("sequence_lens", {2}, sequence_lens);
|
||||
test.AddInput<int64_t>("sequence_lens", {2}, sequence_lens);
|
||||
test.AddOutput<float>("Y", {2, 3, 2}, expected_output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ReverseSequenceTest, Strings) {
|
||||
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
|
||||
OpTester test("ReverseSequence", 10);
|
||||
std::vector<std::string> input = {"0", "4 string longer than 16 chars that requires its own buffer",
|
||||
"1", "5",
|
||||
"2", "6",
|
||||
"3", "7"};
|
||||
|
||||
std::vector<int32_t> sequence_lens = {4, 3};
|
||||
std::vector<int64_t> sequence_lens = {4, 3};
|
||||
std::vector<std::string> expected_output = {"3", "6",
|
||||
"2", "5",
|
||||
"1", "4 string longer than 16 chars that requires its own buffer",
|
||||
|
|
@ -90,7 +90,7 @@ TEST(ReverseSequenceTest, Strings) {
|
|||
test.AddAttribute("time_axis", int64_t(0));
|
||||
|
||||
test.AddInput<std::string>("input", {4, 2, 1}, input);
|
||||
test.AddInput<int32_t>("sequence_lens", {2}, sequence_lens);
|
||||
test.AddInput<int64_t>("sequence_lens", {2}, sequence_lens);
|
||||
test.AddOutput<std::string>("Y", {4, 2, 1}, expected_output);
|
||||
test.Run();
|
||||
}
|
||||
|
|
@ -103,16 +103,16 @@ TEST(ReverseSequenceTest, InvalidInput) {
|
|||
auto check_bad_axis = [&](int64_t batch_dim, int64_t seq_dim,
|
||||
const std::vector<int64_t>& input_shape,
|
||||
const std::string err_msg) {
|
||||
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
|
||||
OpTester test("ReverseSequence", 10);
|
||||
std::vector<int64_t> input(batch_size * seq_size, 0);
|
||||
std::vector<int32_t> sequence_lens(batch_size, 1);
|
||||
std::vector<int64_t> sequence_lens(batch_size, 1);
|
||||
std::vector<int64_t> expected_output = input;
|
||||
|
||||
test.AddAttribute("batch_axis", batch_dim);
|
||||
test.AddAttribute("time_axis", seq_dim);
|
||||
|
||||
test.AddInput<int64_t>("input", input_shape, input);
|
||||
test.AddInput<int32_t>("sequence_lens", {batch_size}, sequence_lens);
|
||||
test.AddInput<int64_t>("sequence_lens", {batch_size}, sequence_lens);
|
||||
test.AddOutput<int64_t>("Y", input_shape, expected_output);
|
||||
test.Run(test::OpTester::ExpectResult::kExpectFailure, err_msg);
|
||||
};
|
||||
|
|
@ -124,12 +124,12 @@ TEST(ReverseSequenceTest, InvalidInput) {
|
|||
|
||||
// invalid sequence_lens size
|
||||
{
|
||||
OpTester test("ReverseSequence", 1, onnxruntime::kMSDomain);
|
||||
OpTester test("ReverseSequence", 10);
|
||||
|
||||
// Bad data_format value
|
||||
std::vector<int64_t> input = {0, 1, 2, 3,
|
||||
4, 5, 6, 7};
|
||||
std::vector<int32_t> sequence_lens = {4, 3, 4};
|
||||
std::vector<int64_t> sequence_lens = {4, 3, 4};
|
||||
std::vector<int64_t> expected_output = {3, 2, 1, 0,
|
||||
6, 5, 4, 7};
|
||||
|
||||
|
|
@ -137,7 +137,7 @@ TEST(ReverseSequenceTest, InvalidInput) {
|
|||
test.AddAttribute("time_axis", int64_t(1));
|
||||
|
||||
test.AddInput<int64_t>("input", {2, 4, 1}, input);
|
||||
test.AddInput<int32_t>("sequence_lens", {3}, sequence_lens);
|
||||
test.AddInput<int64_t>("sequence_lens", {3}, sequence_lens);
|
||||
test.AddOutput<int64_t>("Y", {2, 4, 1}, expected_output);
|
||||
test.Run(test::OpTester::ExpectResult::kExpectFailure,
|
||||
"sequence_lens shape must be {batch_size}. Got:{3}. batch_size=2");
|
||||
Loading…
Reference in a new issue