Implement ConcatFromSequence (#2106)

This commit is contained in:
Hariharan Seshadri 2019-10-19 18:26:10 -07:00 committed by Changming Sun
parent d1159b7008
commit ac3d2ad897
7 changed files with 370 additions and 72 deletions

View file

@ -398,6 +398,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Se
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SequenceInsert);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SequenceErase);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SequenceConstruct);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ConcatFromSequence);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SplitToSequence);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gemm);
@ -1034,7 +1035,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SequenceInsert)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SequenceErase)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SequenceConstruct)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SplitToSequence)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ConcatFromSequence)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SplitToSequence)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, GatherElements)>,

View file

@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cpu/sequence/concat_from_sequence.h"
#include "core/framework/tensorprotoutils.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/framework/TensorSeq.h"
using namespace onnxruntime::common;
namespace onnxruntime {
// Kernel registration for 'ConcatFromSequence' (registered with the value 11, which
// matches the version the unit tests pass to OpTester - presumably the ONNX opset).
// The sequence-typed input is constrained via "S" to all sequence-of-tensor types.
ONNX_CPU_OPERATOR_KERNEL(
ConcatFromSequence,
11,
KernelDefBuilder()
.TypeConstraint("S", DataTypeImpl::AllSequenceTensorTypes()),
ConcatFromSequence);
// Entry point of the 'ConcatFromSequence' kernel.
// Collects pointers to the tensors held by the input sequence and delegates the
// validation (PrepareForCompute) and the data copy (ComputeImpl) to ConcatBase.
Status ConcatFromSequence::Compute(OpKernelContext* ctx) const {
  const auto* X = ctx->Input<TensorSeq>(0);
  ORT_ENFORCE(X != nullptr, "Got nullptr for sequence input.");

  // PrepareForCompute() works on raw tensor pointers, so gather the address of
  // every tensor stored in the sequence (sequence order is preserved)
  const auto& sequence_tensors = X->tensors;
  const int tensor_count = static_cast<int>(sequence_tensors.size());

  std::vector<const Tensor*> tensor_ptrs;
  tensor_ptrs.reserve(tensor_count);
  for (const auto& tensor : sequence_tensors) {
    tensor_ptrs.push_back(&tensor);
  }

  // Validate the inputs and collect the metadata needed for the copy
  Prepare prepared;
  auto prepare_status = PrepareForCompute(ctx, tensor_ptrs, prepared);
  if (!prepare_status.IsOK()) {
    return prepare_status;
  }

  // Only run the copy when the output actually holds elements
  if (prepared.output_num_elements != 0) {
    return ComputeImpl(prepared);
  }

  return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu//tensor/concat.h"
namespace onnxruntime {
// CPU kernel for the 'ConcatFromSequence' op.
// Shares its input validation and copy logic with 'Concat' through ConcatBase.
class ConcatFromSequence final : public OpKernel, public ConcatBase {
 public:
  // The second ConcatBase argument (is_sequence_op) is set to true so that the
  // base class also reads the 'new_axis' attribute.
  explicit ConcatFromSequence(const OpKernelInfo& info)
      : OpKernel(info),
        ConcatBase(info, true) {}

  Status Compute(OpKernelContext* context) const override;
};
} //namespace onnxruntime

View file

@ -3,6 +3,7 @@
#include "core/providers/cpu/tensor/concat.h"
#include "core/providers/common.h"
#include "core/framework/TensorSeq.h"
namespace onnxruntime {
@ -20,25 +21,41 @@ ONNX_CPU_OPERATOR_KERNEL(
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
Concat);
Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, int input_count, Prepare& p) const {
// this method will be shared between 'Concat' (CPU and GPU) and
// 'ConcatFromSequence' ('concat' and 'stack' modes) to validate inputs
Status ConcatBase::PrepareForCompute(OpKernelContext* ctx,
const std::vector<const Tensor*>& input_tensors,
Prepare& p) const {
int input_count = static_cast<int>(input_tensors.size());
// Must have at least one input to concat
ORT_RETURN_IF_NOT(input_count >= 1, "Must have 1 or more inputs");
const Tensor* tensor_pointer = ctx->Input<Tensor>(0);
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
const Tensor* tensor_pointer = input_tensors[0];
if (tensor_pointer == nullptr)
return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
const Tensor& inputs_0 = *tensor_pointer;
const auto& inputs_0_dims = inputs_0.Shape().GetDims();
const size_t inputs_0_rank = inputs_0_dims.size();
ORT_RETURN_IF_NOT(inputs_0_rank > 0, "Cannot concatenate scalars");
p.axis = static_cast<uint64_t>(HandleNegativeAxis(axis_, inputs_0.Shape().NumDimensions()));
// Cannot concatenate scalars (but they can be stacked)
if (!is_stack_)
ORT_RETURN_IF_NOT(inputs_0_rank > 0, "Cannot concatenate scalars");
// Handle and fix negative axis
// In 'stack' mode, the accepted range depends on the output rank (which is one more than the input rank)
p.axis = static_cast<uint64_t>(HandleNegativeAxis(axis_, !is_stack_ ? inputs_0_rank : inputs_0_rank + 1));
// Note if input tensor is empty for later use (it's expensive to call Size() on TensorShape)
std::vector<int64_t> input_tensor_sizes(input_count);
// Assign the number of values in the first input tensor
input_tensor_sizes[0] = inputs_0.Shape().Size();
// cache num of elements in tensor for later use
// as it's expensive to call Size() on TensorShape over and over
std::vector<size_t> tensor_num_elements(static_cast<size_t>(input_count));
// Ensure all of the non concatenated axes match each other
for (int index = 1; index < input_count; index++) {
size_t num_elements = 1;
tensor_pointer = ctx->Input<Tensor>(index);
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
tensor_pointer = input_tensors[index];
if (tensor_pointer == nullptr)
return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
auto& inputs_n = *tensor_pointer;
const auto& inputs_n_dims = inputs_n.Shape().GetDims();
const size_t inputs_n_rank = inputs_n_dims.size();
@ -46,97 +63,122 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, int input_count, Prep
"Ranks of input data are different, cannot concatenate them. expected rank: ",
inputs_0_rank, " got: ", inputs_n_rank);
// Ensure all the other (non-concat) axes match
int64_t tensor_size = 1;
for (size_t axis_index = 0; axis_index < inputs_0_rank; ++axis_index) {
num_elements *= inputs_n_dims[axis_index];
if (axis_index == p.axis)
auto dim_value = inputs_n_dims[axis_index];
tensor_size *= dim_value;
// In 'concat' mode, the axis to be concatenated may be different
// But in 'stack' mode, all input shapes must be the same and must be validated
if (!is_stack_ && axis_index == p.axis)
continue;
ORT_RETURN_IF_NOT(inputs_n_dims[axis_index] == inputs_0_dims[axis_index],
ORT_RETURN_IF_NOT(dim_value == inputs_0_dims[axis_index],
"Non concat axis dimensions must match: Axis ",
axis_index, " has mismatched dimensions of ", inputs_n_dims[axis_index],
axis_index, " has mismatched dimensions of ", dim_value,
" and ", inputs_0_dims[axis_index]);
}
tensor_num_elements[index] = num_elements;
}
// Calculate the size of the concatenated axis, and verify all other dimensions match
size_t concat_axis_size = 0;
for (int index = 0; index < input_count; index++) {
tensor_pointer = ctx->Input<Tensor>(index);
concat_axis_size += tensor_pointer->Shape()[int(p.axis)];
input_tensor_sizes[index] = tensor_size; //assign the computed size of the input tensor
}
// Calculate the shape of the output tensor
std::vector<int64_t> dims(inputs_0_rank);
size_t num_elements = 1; // cache size of the first input along the way
for (size_t dimension_index = 0; dimension_index < inputs_0_rank; dimension_index++) {
dims[dimension_index] = inputs_0_dims[dimension_index];
num_elements *= inputs_0_dims[dimension_index];
}
tensor_num_elements[0] = num_elements;
dims[p.axis] = concat_axis_size;
TensorShape output_shape(dims);
std::vector<int64_t> output_dims = inputs_0_dims;
// 'Concat' mode
if (!is_stack_) {
// While concatenating, the rank of the output is the same as the input rank(s)
auto& concat_result = *ctx->Output(0, output_shape);
p.output_tensor = &concat_result;
// Calculate the size of the concatenated axis
size_t concat_axis_size = 0;
for (int64_t index = 0; index < input_count; index++) {
concat_axis_size += input_tensors[index]->Shape()[static_cast<int>(p.axis)];
}
output_dims[p.axis] = concat_axis_size;
} else { // 'Stack' mode
// While stacking, the rank of the output is one more than the input rank(s).
// Stacking may be thought of as adding a unit dimension (of value 1) in the input tensors,
// and concatenating them on this new axis.
// The value in the corresponding axis of the output will be the number of inputs that are being stacked.
output_dims.insert(output_dims.begin() + p.axis, static_cast<int64_t>(input_count));
}
TensorShape output_shape(output_dims);
// Create output tensor
p.output_tensor = &(*ctx->Output(0, output_shape));
// Make note if output tensor is going to be empty
p.output_num_elements = output_shape.Size();
// if the output tensor is not going to hold any elements,
// there is no need to proceed further
// No need to proceed further if output is going to be empty
if (p.output_num_elements == 0)
return Status::OK();
// The output_axis_pitch is the number of elements to add to move to the next split axis in the output
// The output_axis_pitch is the number of elements to add to move to the next split axis in the output.
// Can handle stacking as well.
p.output_axis_pitch = 1;
for (size_t i = inputs_0_rank; i-- > p.axis;) p.output_axis_pitch *= dims[i];
auto output_rank = !is_stack_ ? inputs_0_rank : inputs_0_rank + 1;
for (size_t i = output_rank; i-- > p.axis;) {
p.output_axis_pitch *= output_dims[i];
}
// Fill the 'Prepare' struct with available information
p.inputs.reserve(input_count);
for (int input_index = 0; input_index < input_count; input_index++) {
const Tensor* data_n_ptr = ctx->Input<Tensor>(input_index);
const Tensor* data_n_ptr = input_tensors[input_index];
auto& data_n = *data_n_ptr;
ORT_RETURN_IF_NOT(data_n.DataType() == concat_result.DataType());
// Type sanity check (Make sure we are working on homogeneous types)
ORT_RETURN_IF_NOT(data_n.DataType() == p.output_tensor->DataType());
// The input_axis_pitch is the number of elements to add to move to the next split axis in the input
// Can handle stacking as well (as the "new dummy dimension" in the input is of unit value).
// TODO: Minor optimization possibility: This input_axis_pitch will be common across all inputs
// in 'ConcatFromSequence' (stack mode). It has to be computed for each input only while concatenating.
int64_t input_axis_pitch = 1;
const auto& data_dims = data_n.Shape().GetDims();
for (size_t i = inputs_0_rank; i-- > p.axis;) input_axis_pitch *= data_dims[i];
for (size_t i = inputs_0_rank; i-- > p.axis;) {
input_axis_pitch *= data_dims[i];
}
p.inputs.push_back({&data_n, tensor_num_elements[input_index], input_axis_pitch});
p.inputs.push_back({&data_n, input_axis_pitch, input_tensor_sizes[input_index]});
}
// Make note if the input Tensors of type 'string'
p.is_string_type = p.inputs[0].tensor->DataType() == DataTypeImpl::GetType<std::string>();
return Status::OK();
}
Status Concat::Compute(OpKernelContext* ctx) const {
auto input_count = Node().InputArgCount().front();
Prepare p;
ORT_RETURN_IF_ERROR(PrepareForCompute(ctx, input_count, p));
// return at this point if output tensor is going to be empty
if (p.output_num_elements == 0)
return Status::OK();
auto is_string_type = ctx->Input<Tensor>(0)->DataType() == DataTypeImpl::GetType<std::string>();
// This method computes the output tensor for Concat/ConcatFromSequence ops
Status ConcatBase::ComputeImpl(Prepare& p) const {
int input_count = static_cast<int>(p.inputs.size());
int64_t initial_output_offset = 0; // initial offset for each input
auto element_bytes = p.output_tensor->DataType()->Size();
for (int input_index = 0; input_index < input_count; input_index++) {
const auto& prep = p.inputs[input_index];
// no data in this tensor - so skip it
if (prep.num_elements == 0)
continue;
auto input_axis_pitch = prep.axis_pitch;
const uint8_t* input = static_cast<const uint8_t*>(prep.tensor->DataRaw());
auto input_size = prep.num_elements;
// Copy the data across. For every 'input_axis_pitch' values copied, we move over by the 'output_axis_pitch'
// TODO: Optimization possibility: There are cases where we simply need to "merge" raw buffers and this
// could be done without the pointer house-keeping as below. Some scenarios where this is possible are:
// 1) Concatenating on input axis = 0
// 2) Stacking on output axis = 0
// 3) Stacking scalars
uint8_t* output = static_cast<uint8_t*>(p.output_tensor->MutableDataRaw());
int64_t cur_out_offset = 0;
int64_t cur_in_offset = 0;
for (size_t idx_copy = 0, end = input_size / input_axis_pitch; idx_copy < end; ++idx_copy) {
if (is_string_type) {
if (p.is_string_type) {
size_t out = initial_output_offset + cur_out_offset;
for (int idx_item = 0; idx_item < input_axis_pitch; ++idx_item) {
reinterpret_cast<std::string*>(output)[out + idx_item] =
@ -159,4 +201,30 @@ Status Concat::Compute(OpKernelContext* ctx) const {
return Status::OK();
}
// Entry point of the 'Concat' kernel.
// Collects the node's input tensors and delegates the validation
// (PrepareForCompute) and the data copy (ComputeImpl) to ConcatBase.
Status Concat::Compute(OpKernelContext* ctx) const {
  // How many tensors this node has to concatenate
  const auto num_inputs = Node().InputArgCount().front();

  // PrepareForCompute() works on raw tensor pointers, so fetch each input once
  std::vector<const Tensor*> tensor_ptrs;
  tensor_ptrs.reserve(num_inputs);
  for (int input_idx = 0; input_idx < num_inputs; ++input_idx) {
    tensor_ptrs.push_back(ctx->Input<Tensor>(input_idx));
  }

  // Validate the inputs and collect the metadata needed for the copy
  Prepare prepared;
  auto prepare_status = PrepareForCompute(ctx, tensor_ptrs, prepared);
  if (!prepare_status.IsOK()) {
    return prepare_status;
  }

  // Only run the copy when the output actually holds elements
  if (prepared.output_num_elements != 0) {
    return ComputeImpl(prepared);
  }

  return Status::OK();
}
} // namespace onnxruntime

View file

@ -10,31 +10,45 @@
namespace onnxruntime {
// Holds the inputs and the metadata gathered by ConcatBase::PrepareForCompute()
// for later use by ConcatBase::ComputeImpl().
// The scalar/pointer members carry default initializers so that a freshly
// declared `Prepare p;` never exposes indeterminate values, even when
// PrepareForCompute() returns before filling every field.
struct Prepare {
  // Per-input bookkeeping. Intentionally left as a plain aggregate (no default
  // member initializers) because it is brace-initialized in member order:
  // {tensor, axis_pitch, num_elements}.
  struct InputInfo {
    const Tensor* tensor;   // the input tensor itself (not owned)
    int64_t axis_pitch;     // number of elements to add to move to the next split axis in the input
    int64_t num_elements;   // total number of elements in the input (0 => input is skipped)
  };
  std::vector<InputInfo> inputs;
  int64_t output_num_elements = 0;  // total number of elements in the output tensor
  int64_t output_axis_pitch = 1;    // number of elements to add to move to the next split axis in the output
  Tensor* output_tensor = nullptr;  // output tensor obtained from the kernel context (not owned)
  uint64_t axis = 0;                // the axis after negative values have been resolved
  bool is_string_type = false;      // true when the tensors hold std::string elements
};
class ConcatBase {
protected:
ConcatBase(const OpKernelInfo& info) {
if (!info.GetAttr("axis", &axis_).IsOK()) {
ConcatBase(const OpKernelInfo& info, bool is_sequence_op = false) {
if (!info.GetAttr<int64_t>("axis", &axis_).IsOK()) {
ORT_ENFORCE(false, "Must have valid 'axis' attribute");
}
is_sequence_op_ = is_sequence_op;
if (is_sequence_op) { // Only ConcatFromSequence supports stacking
is_stack_ = info.GetAttrOrDefault<int64_t>("new_axis", 0) == 0 ? false : true;
}
}
struct Prepare {
struct InputInfo {
const Tensor* tensor;
size_t num_elements;
int64_t axis_pitch;
};
std::vector<InputInfo> inputs;
int64_t output_num_elements;
int64_t output_axis_pitch;
Tensor* output_tensor;
uint64_t axis;
};
// the core method that will be invoked by the 'Concat' (CPU and GPU)
// and 'ConcatFromSequence' kernels
Status PrepareForCompute(OpKernelContext* ctx, const std::vector<const Tensor*>& input_tensors,
Prepare& p) const;
Status PrepareForCompute(OpKernelContext* ctx, int input_count, Prepare& p) const;
Status ComputeImpl(Prepare& p) const;
private:
int64_t axis_;
bool is_stack_ = false;
bool is_sequence_op_;
};
class Concat final : public OpKernel, public ConcatBase {

View file

@ -18,8 +18,15 @@ ONNX_OPERATOR_KERNEL_EX(
Status Concat::ComputeInternal(OpKernelContext* ctx) const {
auto input_count = Node().InputArgCount().front();
// Hold pointers to the input tensors to be used in the PrepareForCompute() step
std::vector<const Tensor*> input_tensors;
input_tensors.reserve(input_count);
for (int i = 0; i < input_count; ++i) {
input_tensors.push_back(ctx->Input<Tensor>(i));
}
Prepare p;
ORT_RETURN_IF_ERROR(PrepareForCompute(ctx, input_count, p));
ORT_RETURN_IF_ERROR(PrepareForCompute(ctx, input_tensors, p));
// Return at this point if output tensor is going to be empty
if (p.output_num_elements == 0)

View file

@ -0,0 +1,136 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
// Stacking two {1, 2} float tensors on a new leading axis yields a {2, 1, 2} output.
TEST(SequenceOpsTest, ConcatFromSequence_Stack_Axis0) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 0);
  tester.AddAttribute<int64_t>("new_axis", 1);  // new_axis = 1 selects stack mode

  SeqTensors<float> sequence;
  sequence.AddTensor({1, 2}, {0.0f, 1.0f});
  sequence.AddTensor({1, 2}, {2.0f, 3.0f});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<float>("I", {2, 1, 2}, {0.0f, 1.0f, 2.0f, 3.0f});
  tester.Run();
}
// Stacking two {1, 2} int32 tensors on a new axis 1 yields a {1, 2, 2} output.
TEST(SequenceOpsTest, ConcatFromSequence_Stack_Axis1) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 1);
  tester.AddAttribute<int64_t>("new_axis", 1);  // new_axis = 1 selects stack mode

  SeqTensors<int32_t> sequence;
  sequence.AddTensor({1, 2}, {0, 1});
  sequence.AddTensor({1, 2}, {2, 3});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<int32_t>("I", {1, 2, 2}, {0, 1, 2, 3});
  tester.Run();
}
// Stacking on a new trailing axis (axis == input rank) interleaves the input
// values: the {1, 2, 2} output holds {0, 2, 1, 3}.
TEST(SequenceOpsTest, ConcatFromSequence_Stack_Axis2) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 2);
  tester.AddAttribute<int64_t>("new_axis", 1);  // new_axis = 1 selects stack mode

  SeqTensors<int64_t> sequence;
  sequence.AddTensor({1, 2}, {0, 1});
  sequence.AddTensor({1, 2}, {2, 3});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<int64_t>("I", {1, 2, 2}, {0, 2, 1, 3});
  tester.Run();
}
// Stacking three empty {1, 0} tensors on a new axis 1 yields an empty {1, 3, 0} output.
TEST(SequenceOpsTest, ConcatFromSequence_Stack_Axis1_WithEmptyInput) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 1);
  tester.AddAttribute<int64_t>("new_axis", 1);  // new_axis = 1 selects stack mode

  SeqTensors<int64_t> sequence;
  sequence.AddTensor({1, 0}, {});
  sequence.AddTensor({1, 0}, {});
  sequence.AddTensor({1, 0}, {});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<int64_t>("I", {1, 3, 0}, {});
  tester.Run();
}
// Scalars (rank 0) may be stacked: three scalars produce a {3} output.
TEST(SequenceOpsTest, ConcatFromSequence_Stack_ScalarInputs) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 0);
  tester.AddAttribute<int64_t>("new_axis", 1);  // new_axis = 1 selects stack mode

  SeqTensors<int64_t> sequence;
  sequence.AddTensor({}, {1});
  sequence.AddTensor({}, {2});
  sequence.AddTensor({}, {3});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<int64_t>("I", {3}, {1, 2, 3});
  tester.Run();
}
// Concatenating two {1, 2} float tensors along axis 0 yields a {2, 2} output.
TEST(SequenceOpsTest, ConcatFromSequence_Concat_Axis0) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 0);
  tester.AddAttribute<int64_t>("new_axis", 0);  // new_axis = 0 selects concat mode

  SeqTensors<float> sequence;
  sequence.AddTensor({1, 2}, {0.0f, 1.0f});
  sequence.AddTensor({1, 2}, {2.0f, 3.0f});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<float>("I", {2, 2}, {0.0f, 1.0f, 2.0f, 3.0f});
  tester.Run();
}
// Concatenating two {1, 2} int32 tensors along axis 1 yields a {1, 4} output.
TEST(SequenceOpsTest, ConcatFromSequence_Concat_Axis1) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 1);
  tester.AddAttribute<int64_t>("new_axis", 0);  // new_axis = 0 selects concat mode

  SeqTensors<int32_t> sequence;
  sequence.AddTensor({1, 2}, {0, 1});
  sequence.AddTensor({1, 2}, {2, 3});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<int32_t>("I", {1, 4}, {0, 1, 2, 3});
  tester.Run();
}
// In concat mode the axis must lie within the input rank, so axis 2 on rank-2
// inputs is expected to be rejected.
TEST(SequenceOpsTest, ConcatFromSequence_Concat_Axis2) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 2);
  tester.AddAttribute<int64_t>("new_axis", 0);  // new_axis = 0 selects concat mode

  SeqTensors<int64_t> sequence;
  sequence.AddTensor({1, 2}, {0, 1});
  sequence.AddTensor({1, 2}, {2, 3});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<int64_t>("I", {1, 2, 2}, {0, 2, 1, 3});
  tester.Run(OpTester::ExpectResult::kExpectFailure, "axis 2 is not in valid range [-2,1]");
}
// Concatenating three empty {1, 0} tensors along axis 1 yields an empty {1, 0} output.
TEST(SequenceOpsTest, ConcatFromSequence_Concat_Axis1_WithEmptyInput) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 1);
  tester.AddAttribute<int64_t>("new_axis", 0);  // new_axis = 0 selects concat mode

  SeqTensors<int64_t> sequence;
  sequence.AddTensor({1, 0}, {});
  sequence.AddTensor({1, 0}, {});
  sequence.AddTensor({1, 0}, {});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<int64_t>("I", {1, 0}, {});
  tester.Run();
}
// Scalars (rank 0) cannot be concatenated; the kernel is expected to fail.
TEST(SequenceOpsTest, ConcatFromSequence_Concat_ScalarInputs) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 0);
  tester.AddAttribute<int64_t>("new_axis", 0);  // new_axis = 0 selects concat mode

  SeqTensors<int64_t> sequence;
  sequence.AddTensor({}, {1});
  sequence.AddTensor({}, {2});
  sequence.AddTensor({}, {3});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<int64_t>("I", {3}, {1, 2, 3});
  tester.Run(OpTester::ExpectResult::kExpectFailure,
             "Cannot concatenate scalars");
}
} // namespace test
} // namespace onnxruntime