mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Implement ConcatFromSequence (#2106)
This commit is contained in:
parent
d1159b7008
commit
ac3d2ad897
7 changed files with 370 additions and 72 deletions
|
|
@ -398,6 +398,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Se
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SequenceInsert);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SequenceErase);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SequenceConstruct);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ConcatFromSequence);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SplitToSequence);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterND);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gemm);
|
||||
|
|
@ -1034,7 +1035,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SequenceInsert)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SequenceErase)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SequenceConstruct)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SplitToSequence)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ConcatFromSequence)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, SplitToSequence)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, GatherElements)>,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,51 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cpu/sequence/concat_from_sequence.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "core/framework/TensorSeq.h"
|
||||
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Registers the CPU 'ConcatFromSequence' kernel (version 11). The "S" type constraint is
// set to DataTypeImpl::AllSequenceTensorTypes().
ONNX_CPU_OPERATOR_KERNEL(ConcatFromSequence,
                         11,
                         KernelDefBuilder().TypeConstraint("S", DataTypeImpl::AllSequenceTensorTypes()),
                         ConcatFromSequence);
|
||||
|
||||
// core Compute() method for the 'ConcatFromSequence' kernel
|
||||
// Compute() for the 'ConcatFromSequence' kernel.
// Gathers a pointer to every tensor held by the input sequence, hands them to
// PrepareForCompute() (which validates them and creates the output tensor), and then
// delegates the data copy to ComputeImpl().
Status ConcatFromSequence::Compute(OpKernelContext* ctx) const {
  const auto* sequence = ctx->Input<TensorSeq>(0);
  ORT_ENFORCE(sequence != nullptr, "Got nullptr for sequence input.");

  // PrepareForCompute() consumes raw tensor pointers, so collect one per sequence element
  const auto& seq_tensors = sequence->tensors;
  std::vector<const Tensor*> tensor_ptrs;
  tensor_ptrs.reserve(seq_tensors.size());
  for (const auto& seq_tensor : seq_tensors) {
    tensor_ptrs.push_back(&seq_tensor);
  }

  // Validate the inputs and fill in the metadata needed by the copy step
  Prepare prepared;
  auto prepare_status = PrepareForCompute(ctx, tensor_ptrs, prepared);
  if (!prepare_status.IsOK()) {
    return prepare_status;
  }

  // An empty output tensor means there is nothing left to copy
  if (prepared.output_num_elements == 0) {
    return Status::OK();
  }

  // Fill the output tensor
  return ComputeImpl(prepared);
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cpu//tensor/concat.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
class ConcatFromSequence final : public OpKernel, public ConcatBase {
|
||||
public:
|
||||
explicit ConcatFromSequence(const OpKernelInfo& info) : OpKernel(info), ConcatBase(info, true) {
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} //namespace onnxruntime
|
||||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include "core/providers/cpu/tensor/concat.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/framework/TensorSeq.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -20,25 +21,41 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
|
||||
Concat);
|
||||
|
||||
Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, int input_count, Prepare& p) const {
|
||||
// this method will be shared between 'Concat' (CPU and GPU) and
|
||||
// 'ConcatFromSequence' ('concat' and 'stack' modes) to validate inputs
|
||||
Status ConcatBase::PrepareForCompute(OpKernelContext* ctx,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
Prepare& p) const {
|
||||
int input_count = static_cast<int>(input_tensors.size());
|
||||
// Must have at least one input to concat
|
||||
ORT_RETURN_IF_NOT(input_count >= 1, "Must have 1 or more inputs");
|
||||
const Tensor* tensor_pointer = ctx->Input<Tensor>(0);
|
||||
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
|
||||
const Tensor* tensor_pointer = input_tensors[0];
|
||||
if (tensor_pointer == nullptr)
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
|
||||
const Tensor& inputs_0 = *tensor_pointer;
|
||||
const auto& inputs_0_dims = inputs_0.Shape().GetDims();
|
||||
const size_t inputs_0_rank = inputs_0_dims.size();
|
||||
ORT_RETURN_IF_NOT(inputs_0_rank > 0, "Cannot concatenate scalars");
|
||||
|
||||
p.axis = static_cast<uint64_t>(HandleNegativeAxis(axis_, inputs_0.Shape().NumDimensions()));
|
||||
// Cannot concatenate scalars (but they can be stacked)
|
||||
if (!is_stack_)
|
||||
ORT_RETURN_IF_NOT(inputs_0_rank > 0, "Cannot concatenate scalars");
|
||||
|
||||
// Handle and fix negative axis
|
||||
// In 'stack' mode, the accepted range depends on the output rank (which is one more than the input rank)
|
||||
p.axis = static_cast<uint64_t>(HandleNegativeAxis(axis_, !is_stack_ ? inputs_0_rank : inputs_0_rank + 1));
|
||||
|
||||
// Note if input tensor is empty for later use (it's expensive to call Size() on TensorShape)
|
||||
std::vector<int64_t> input_tensor_sizes(input_count);
|
||||
// Assign the number of values in the first input tensor
|
||||
input_tensor_sizes[0] = inputs_0.Shape().Size();
|
||||
|
||||
// cache num of elements in tensor for later use
|
||||
// as it's expensive to call Size() on TensorShape over and over
|
||||
std::vector<size_t> tensor_num_elements(static_cast<size_t>(input_count));
|
||||
// Ensure all of the non concatenated axes match each other
|
||||
for (int index = 1; index < input_count; index++) {
|
||||
size_t num_elements = 1;
|
||||
tensor_pointer = ctx->Input<Tensor>(index);
|
||||
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
tensor_pointer = input_tensors[index];
|
||||
if (tensor_pointer == nullptr)
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
auto& inputs_n = *tensor_pointer;
|
||||
const auto& inputs_n_dims = inputs_n.Shape().GetDims();
|
||||
const size_t inputs_n_rank = inputs_n_dims.size();
|
||||
|
|
@ -46,97 +63,122 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, int input_count, Prep
|
|||
"Ranks of input data are different, cannot concatenate them. expected rank: ",
|
||||
inputs_0_rank, " got: ", inputs_n_rank);
|
||||
// Ensure all the other (non-concat) axes match
|
||||
int64_t tensor_size = 1;
|
||||
for (size_t axis_index = 0; axis_index < inputs_0_rank; ++axis_index) {
|
||||
num_elements *= inputs_n_dims[axis_index];
|
||||
if (axis_index == p.axis)
|
||||
auto dim_value = inputs_n_dims[axis_index];
|
||||
tensor_size *= dim_value;
|
||||
|
||||
// In 'concat' mode, the axis to be concatenated may be different
|
||||
// But in 'stack' mode, all input shapes must be the same and must be validated
|
||||
if (!is_stack_ && axis_index == p.axis)
|
||||
continue;
|
||||
ORT_RETURN_IF_NOT(inputs_n_dims[axis_index] == inputs_0_dims[axis_index],
|
||||
|
||||
ORT_RETURN_IF_NOT(dim_value == inputs_0_dims[axis_index],
|
||||
"Non concat axis dimensions must match: Axis ",
|
||||
axis_index, " has mismatched dimensions of ", inputs_n_dims[axis_index],
|
||||
axis_index, " has mismatched dimensions of ", dim_value,
|
||||
" and ", inputs_0_dims[axis_index]);
|
||||
}
|
||||
tensor_num_elements[index] = num_elements;
|
||||
}
|
||||
|
||||
// Calculate the size of the concatenated axis, and verify all other dimensions match
|
||||
size_t concat_axis_size = 0;
|
||||
for (int index = 0; index < input_count; index++) {
|
||||
tensor_pointer = ctx->Input<Tensor>(index);
|
||||
concat_axis_size += tensor_pointer->Shape()[int(p.axis)];
|
||||
input_tensor_sizes[index] = tensor_size; //assign the computed size of the input tensor
|
||||
}
|
||||
|
||||
// Calculate the shape of the output tensor
|
||||
std::vector<int64_t> dims(inputs_0_rank);
|
||||
size_t num_elements = 1; // cache size of the first input along the way
|
||||
for (size_t dimension_index = 0; dimension_index < inputs_0_rank; dimension_index++) {
|
||||
dims[dimension_index] = inputs_0_dims[dimension_index];
|
||||
num_elements *= inputs_0_dims[dimension_index];
|
||||
}
|
||||
tensor_num_elements[0] = num_elements;
|
||||
dims[p.axis] = concat_axis_size;
|
||||
TensorShape output_shape(dims);
|
||||
std::vector<int64_t> output_dims = inputs_0_dims;
|
||||
// 'Concat' mode
|
||||
if (!is_stack_) {
|
||||
// While concatenating, the rank of the output is the same as the input rank(s)
|
||||
|
||||
auto& concat_result = *ctx->Output(0, output_shape);
|
||||
p.output_tensor = &concat_result;
|
||||
// Calculate the size of the concatenated axis
|
||||
size_t concat_axis_size = 0;
|
||||
for (int64_t index = 0; index < input_count; index++) {
|
||||
concat_axis_size += input_tensors[index]->Shape()[static_cast<int>(p.axis)];
|
||||
}
|
||||
|
||||
output_dims[p.axis] = concat_axis_size;
|
||||
} else { // 'Stack' mode
|
||||
// While stacking, the rank of the output is one more than the input rank(s).
|
||||
// Stacking may be thought of as adding a unit dimension (of value 1) in the input tensors,
|
||||
// and concatenating them on this new axis.
|
||||
// The value in the corresponding axis of the output will be the number of inputs that are being stacked.
|
||||
output_dims.insert(output_dims.begin() + p.axis, static_cast<int64_t>(input_count));
|
||||
}
|
||||
|
||||
TensorShape output_shape(output_dims);
|
||||
|
||||
// Create output tensor
|
||||
p.output_tensor = &(*ctx->Output(0, output_shape));
|
||||
|
||||
// Make note if output tensor is going to be empty
|
||||
p.output_num_elements = output_shape.Size();
|
||||
|
||||
// if the output tensor is not going to hold any elements,
|
||||
// there is no need to proceed further
|
||||
// No need to proceed further if output is going to be empty
|
||||
if (p.output_num_elements == 0)
|
||||
return Status::OK();
|
||||
|
||||
// The output_axis_pitch is the number of elements to add to move to the next split axis in the output
|
||||
// The output_axis_pitch is the number of elements to add to move to the next split axis in the output.
|
||||
// Can handle stacking as well.
|
||||
p.output_axis_pitch = 1;
|
||||
for (size_t i = inputs_0_rank; i-- > p.axis;) p.output_axis_pitch *= dims[i];
|
||||
auto output_rank = !is_stack_ ? inputs_0_rank : inputs_0_rank + 1;
|
||||
for (size_t i = output_rank; i-- > p.axis;) {
|
||||
p.output_axis_pitch *= output_dims[i];
|
||||
}
|
||||
|
||||
// Fill the 'Prepare' struct with available information
|
||||
p.inputs.reserve(input_count);
|
||||
for (int input_index = 0; input_index < input_count; input_index++) {
|
||||
const Tensor* data_n_ptr = ctx->Input<Tensor>(input_index);
|
||||
const Tensor* data_n_ptr = input_tensors[input_index];
|
||||
auto& data_n = *data_n_ptr;
|
||||
|
||||
ORT_RETURN_IF_NOT(data_n.DataType() == concat_result.DataType());
|
||||
// Type sanity check (Make sure we are working on homogeneous types)
|
||||
ORT_RETURN_IF_NOT(data_n.DataType() == p.output_tensor->DataType());
|
||||
|
||||
// The input_axis_pitch is the number of elements to add to move to the next split axis in the input
|
||||
// Can handle stacking as well (as the "new dummy dimension" in the input is of unit value).
|
||||
// TODO: Minor Optimization possibility: This input_axis_pitch will be common across all inputs
|
||||
// in 'ConcatFromSequence' (stack mode). They have to be computed for each input only while concatenating.
|
||||
int64_t input_axis_pitch = 1;
|
||||
const auto& data_dims = data_n.Shape().GetDims();
|
||||
for (size_t i = inputs_0_rank; i-- > p.axis;) input_axis_pitch *= data_dims[i];
|
||||
for (size_t i = inputs_0_rank; i-- > p.axis;) {
|
||||
input_axis_pitch *= data_dims[i];
|
||||
}
|
||||
|
||||
p.inputs.push_back({&data_n, tensor_num_elements[input_index], input_axis_pitch});
|
||||
p.inputs.push_back({&data_n, input_axis_pitch, input_tensor_sizes[input_index]});
|
||||
}
|
||||
|
||||
// Make note if the input Tensors are of type 'string'
|
||||
p.is_string_type = p.inputs[0].tensor->DataType() == DataTypeImpl::GetType<std::string>();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Concat::Compute(OpKernelContext* ctx) const {
|
||||
auto input_count = Node().InputArgCount().front();
|
||||
|
||||
Prepare p;
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(ctx, input_count, p));
|
||||
|
||||
// return at this point if output tensor is going to be empty
|
||||
if (p.output_num_elements == 0)
|
||||
return Status::OK();
|
||||
|
||||
auto is_string_type = ctx->Input<Tensor>(0)->DataType() == DataTypeImpl::GetType<std::string>();
|
||||
|
||||
// This method computes the output tensor for Concat/ConcatFromSequence ops
|
||||
Status ConcatBase::ComputeImpl(Prepare& p) const {
|
||||
int input_count = static_cast<int>(p.inputs.size());
|
||||
int64_t initial_output_offset = 0; // initial offset for each input
|
||||
auto element_bytes = p.output_tensor->DataType()->Size();
|
||||
for (int input_index = 0; input_index < input_count; input_index++) {
|
||||
const auto& prep = p.inputs[input_index];
|
||||
|
||||
// no data in this tensor - so skip it
|
||||
if (prep.num_elements == 0)
|
||||
continue;
|
||||
|
||||
auto input_axis_pitch = prep.axis_pitch;
|
||||
const uint8_t* input = static_cast<const uint8_t*>(prep.tensor->DataRaw());
|
||||
|
||||
auto input_size = prep.num_elements;
|
||||
|
||||
// Copy the data across. For every 'input_axis_pitch' values copied, we move over by the 'output_axis_pitch'
|
||||
// TODO: Optimization possibility: There are cases where we simply need to "merge" raw buffers and this
|
||||
// could be done without the pointer house-keeping as below. Some scenarios where this is possible are:
|
||||
// 1) Concatenating on input axis = 0
|
||||
// 2) Stacking on output axis = 0
|
||||
// 3) Stacking scalars
|
||||
uint8_t* output = static_cast<uint8_t*>(p.output_tensor->MutableDataRaw());
|
||||
int64_t cur_out_offset = 0;
|
||||
int64_t cur_in_offset = 0;
|
||||
for (size_t idx_copy = 0, end = input_size / input_axis_pitch; idx_copy < end; ++idx_copy) {
|
||||
if (is_string_type) {
|
||||
if (p.is_string_type) {
|
||||
size_t out = initial_output_offset + cur_out_offset;
|
||||
for (int idx_item = 0; idx_item < input_axis_pitch; ++idx_item) {
|
||||
reinterpret_cast<std::string*>(output)[out + idx_item] =
|
||||
|
|
@ -159,4 +201,30 @@ Status Concat::Compute(OpKernelContext* ctx) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// core Compute() method for the 'Concat' kernel
|
||||
// Compute() for the 'Concat' kernel: gathers the input tensors, validates them through
// PrepareForCompute() and passes the prepared state on to ComputeImpl().
Status Concat::Compute(OpKernelContext* ctx) const {
  // How many tensors are to be concatenated
  auto num_inputs = Node().InputArgCount().front();

  // PrepareForCompute() takes the inputs as a list of tensor pointers
  std::vector<const Tensor*> tensors_to_concat;
  tensors_to_concat.reserve(num_inputs);
  for (int input_idx = 0; input_idx < num_inputs; ++input_idx) {
    tensors_to_concat.push_back(ctx->Input<Tensor>(input_idx));
  }

  // Validate the inputs and collect the metadata used by the copy step
  Prepare prepared;
  auto prepare_status = PrepareForCompute(ctx, tensors_to_concat, prepared);
  if (!prepare_status.IsOK()) {
    return prepare_status;
  }

  // An empty output tensor means there is nothing left to copy
  if (prepared.output_num_elements == 0) {
    return Status::OK();
  }

  // Fill the output tensor
  return ComputeImpl(prepared);
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -10,31 +10,45 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
// structure to hold some inputs and some metadata to be used during Compute()
|
||||
struct Prepare {
|
||||
struct InputInfo {
|
||||
const Tensor* tensor;
|
||||
int64_t axis_pitch;
|
||||
int64_t num_elements;
|
||||
};
|
||||
std::vector<InputInfo> inputs;
|
||||
int64_t output_num_elements;
|
||||
int64_t output_axis_pitch;
|
||||
Tensor* output_tensor;
|
||||
uint64_t axis;
|
||||
bool is_string_type;
|
||||
};
|
||||
|
||||
class ConcatBase {
|
||||
protected:
|
||||
ConcatBase(const OpKernelInfo& info) {
|
||||
if (!info.GetAttr("axis", &axis_).IsOK()) {
|
||||
ConcatBase(const OpKernelInfo& info, bool is_sequence_op = false) {
|
||||
if (!info.GetAttr<int64_t>("axis", &axis_).IsOK()) {
|
||||
ORT_ENFORCE(false, "Must have valid 'axis' attribute");
|
||||
}
|
||||
|
||||
is_sequence_op_ = is_sequence_op;
|
||||
|
||||
if (is_sequence_op) { // Only ConcatFromSequence supports stacking
|
||||
is_stack_ = info.GetAttrOrDefault<int64_t>("new_axis", 0) == 0 ? false : true;
|
||||
}
|
||||
}
|
||||
|
||||
struct Prepare {
|
||||
struct InputInfo {
|
||||
const Tensor* tensor;
|
||||
size_t num_elements;
|
||||
int64_t axis_pitch;
|
||||
};
|
||||
std::vector<InputInfo> inputs;
|
||||
int64_t output_num_elements;
|
||||
int64_t output_axis_pitch;
|
||||
Tensor* output_tensor;
|
||||
uint64_t axis;
|
||||
};
|
||||
// the core method that will be invoked by the 'Concat' (CPU and GPU)
|
||||
// and 'ConcatFromSequence' kernels
|
||||
Status PrepareForCompute(OpKernelContext* ctx, const std::vector<const Tensor*>& input_tensors,
|
||||
Prepare& p) const;
|
||||
|
||||
Status PrepareForCompute(OpKernelContext* ctx, int input_count, Prepare& p) const;
|
||||
Status ComputeImpl(Prepare& p) const;
|
||||
|
||||
private:
|
||||
int64_t axis_;
|
||||
bool is_stack_ = false;
|
||||
bool is_sequence_op_;
|
||||
};
|
||||
|
||||
class Concat final : public OpKernel, public ConcatBase {
|
||||
|
|
|
|||
|
|
@ -18,8 +18,15 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
Status Concat::ComputeInternal(OpKernelContext* ctx) const {
|
||||
auto input_count = Node().InputArgCount().front();
|
||||
|
||||
// Hold pointers to the input tensors to be used in the PrepareForCompute() step
|
||||
std::vector<const Tensor*> input_tensors;
|
||||
input_tensors.reserve(input_count);
|
||||
for (int i = 0; i < input_count; ++i) {
|
||||
input_tensors.push_back(ctx->Input<Tensor>(i));
|
||||
}
|
||||
|
||||
Prepare p;
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(ctx, input_count, p));
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(ctx, input_tensors, p));
|
||||
|
||||
// Return at this point if output tensor is going to be empty
|
||||
if (p.output_num_elements == 0)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,136 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
// Stack mode on axis 0: two float tensors of shape {1, 2} are expected to produce an
// output of shape {2, 1, 2}.
TEST(SequenceOpsTest, ConcatFromSequence_Stack_Axis0) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 0);
  tester.AddAttribute<int64_t>("new_axis", 1);  // request stacking

  SeqTensors<float> sequence;
  sequence.AddTensor({1, 2}, {0.0f, 1.0f});
  sequence.AddTensor({1, 2}, {2.0f, 3.0f});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<float>("I", {2, 1, 2}, {0.0f, 1.0f, 2.0f, 3.0f});
  tester.Run();
}
|
||||
|
||||
// Stack mode on axis 1: two int32 tensors of shape {1, 2} are expected to produce an
// output of shape {1, 2, 2}.
TEST(SequenceOpsTest, ConcatFromSequence_Stack_Axis1) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 1);
  tester.AddAttribute<int64_t>("new_axis", 1);  // request stacking

  SeqTensors<int32_t> sequence;
  sequence.AddTensor({1, 2}, {0, 1});
  sequence.AddTensor({1, 2}, {2, 3});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<int32_t>("I", {1, 2, 2}, {0, 1, 2, 3});
  tester.Run();
}
|
||||
|
||||
// Stack mode on axis 2 (one past the last input axis): two int64 tensors of shape {1, 2}
// are expected to be interleaved element-wise into an output of shape {1, 2, 2}.
TEST(SequenceOpsTest, ConcatFromSequence_Stack_Axis2) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 2);
  tester.AddAttribute<int64_t>("new_axis", 1);  // request stacking

  SeqTensors<int64_t> sequence;
  sequence.AddTensor({1, 2}, {0, 1});
  sequence.AddTensor({1, 2}, {2, 3});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<int64_t>("I", {1, 2, 2}, {0, 2, 1, 3});
  tester.Run();
}
|
||||
|
||||
// Stack mode on axis 1 with element-free inputs: three tensors of shape {1, 0} are
// expected to produce an (empty) output of shape {1, 3, 0}.
TEST(SequenceOpsTest, ConcatFromSequence_Stack_Axis1_WithEmptyInput) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 1);
  tester.AddAttribute<int64_t>("new_axis", 1);  // request stacking

  SeqTensors<int64_t> sequence;
  sequence.AddTensor({1, 0}, {});
  sequence.AddTensor({1, 0}, {});
  sequence.AddTensor({1, 0}, {});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<int64_t>("I", {1, 3, 0}, {});
  tester.Run();
}
|
||||
|
||||
// Stack mode with scalar inputs: three rank-0 tensors are expected to produce a
// 1-D output of shape {3}.
TEST(SequenceOpsTest, ConcatFromSequence_Stack_ScalarInputs) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 0);
  tester.AddAttribute<int64_t>("new_axis", 1);  // request stacking

  SeqTensors<int64_t> sequence;
  sequence.AddTensor({}, {1});
  sequence.AddTensor({}, {2});
  sequence.AddTensor({}, {3});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<int64_t>("I", {3}, {1, 2, 3});
  tester.Run();
}
|
||||
|
||||
// Concat mode on axis 0: two float tensors of shape {1, 2} are expected to produce an
// output of shape {2, 2}.
TEST(SequenceOpsTest, ConcatFromSequence_Concat_Axis0) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 0);
  tester.AddAttribute<int64_t>("new_axis", 0);  // request plain concatenation

  SeqTensors<float> sequence;
  sequence.AddTensor({1, 2}, {0.0f, 1.0f});
  sequence.AddTensor({1, 2}, {2.0f, 3.0f});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<float>("I", {2, 2}, {0.0f, 1.0f, 2.0f, 3.0f});
  tester.Run();
}
|
||||
|
||||
// Concat mode on axis 1: two int32 tensors of shape {1, 2} are expected to produce an
// output of shape {1, 4}.
TEST(SequenceOpsTest, ConcatFromSequence_Concat_Axis1) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 1);
  tester.AddAttribute<int64_t>("new_axis", 0);  // request plain concatenation

  SeqTensors<int32_t> sequence;
  sequence.AddTensor({1, 2}, {0, 1});
  sequence.AddTensor({1, 2}, {2, 3});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<int32_t>("I", {1, 4}, {0, 1, 2, 3});
  tester.Run();
}
|
||||
|
||||
// Concat mode on axis 2 with rank-2 inputs: the run is expected to fail with an
// out-of-range axis error, so the registered output only satisfies the tester's setup.
TEST(SequenceOpsTest, ConcatFromSequence_Concat_Axis2) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 2);
  tester.AddAttribute<int64_t>("new_axis", 0);  // request plain concatenation

  SeqTensors<int64_t> sequence;
  sequence.AddTensor({1, 2}, {0, 1});
  sequence.AddTensor({1, 2}, {2, 3});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<int64_t>("I", {1, 2, 2}, {0, 2, 1, 3});
  tester.Run(OpTester::ExpectResult::kExpectFailure, "axis 2 is not in valid range [-2,1]");
}
|
||||
|
||||
// Concat mode on axis 1 with element-free inputs: three tensors of shape {1, 0} are
// expected to produce an (empty) output of shape {1, 0}.
TEST(SequenceOpsTest, ConcatFromSequence_Concat_Axis1_WithEmptyInput) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 1);
  tester.AddAttribute<int64_t>("new_axis", 0);  // request plain concatenation

  SeqTensors<int64_t> sequence;
  sequence.AddTensor({1, 0}, {});
  sequence.AddTensor({1, 0}, {});
  sequence.AddTensor({1, 0}, {});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<int64_t>("I", {1, 0}, {});
  tester.Run();
}
|
||||
|
||||
// Concat mode with scalar inputs: the run is expected to fail with
// "Cannot concatenate scalars", so the registered output only satisfies the tester's setup.
TEST(SequenceOpsTest, ConcatFromSequence_Concat_ScalarInputs) {
  OpTester tester("ConcatFromSequence", 11);
  tester.AddAttribute<int64_t>("axis", 0);
  tester.AddAttribute<int64_t>("new_axis", 0);  // request plain concatenation

  SeqTensors<int64_t> sequence;
  sequence.AddTensor({}, {1});
  sequence.AddTensor({}, {2});
  sequence.AddTensor({}, {3});
  tester.AddSeqInput("S", sequence);

  tester.AddOutput<int64_t>("I", {3}, {1, 2, 3});
  tester.Run(OpTester::ExpectResult::kExpectFailure, "Cannot concatenate scalars");
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue