mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Add float16 type support to SplitToSequence and make code type independent (#18594)
### Description
Add support for the `float16` type to SplitToSequence to address the issue linked below. Rework the code to make it type-independent; this also reduces binary size by ~11 K.

### Motivation and Context
This PR addresses https://github.com/microsoft/onnxruntime/issues/18481
This commit is contained in:
parent
68209307da
commit
d2dfbf4179
4 changed files with 112 additions and 81 deletions
|
|
@ -373,7 +373,7 @@ Do not modify directly.*
|
|||
|||[13, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|SplitToSequence|*in* input:**T**<br> *in* split:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string)|
|
||||
|SplitToSequence|*in* input:**T**<br> *in* split:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(string)|
|
||||
|Sqrt|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|
||||
|||[6, 12]|**T** = tensor(double), tensor(float)|
|
||||
|Squeeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* squeezed:**T**<br><br>or<br><br>*in* data:**T**<br> *out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
|
|||
|
|
@ -334,27 +334,14 @@ Status SequenceConstruct::Compute(OpKernelContext* context) const {
|
|||
|
||||
// SplitToSequence
|
||||
|
||||
namespace op_kernel_type_control {
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, SplitToSequence, Input, 0,
|
||||
float, double, int32_t, int64_t, std::string);
|
||||
} // namespace op_kernel_type_control
|
||||
|
||||
namespace {
|
||||
using EnabledSplitToSequenceDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, SplitToSequence, Input, 0);
|
||||
} // namespace
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
SplitToSequence,
|
||||
11,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T",
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledSplitToSequenceDataTypes>())
|
||||
BuildKernelDefConstraints<float, MLFloat16, double, int32_t, int64_t, std::string>())
|
||||
.TypeConstraint("S", DataTypeImpl::AllSequenceTensorTypes())
|
||||
.TypeConstraint("I", std::vector<MLDataType>{
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
.TypeConstraint("I", BuildKernelDefConstraints<int32_t, int64_t>()),
|
||||
SplitToSequence);
|
||||
|
||||
SplitToSequence::SplitToSequence(const OpKernelInfo& info) : OpKernel(info) {
|
||||
|
|
@ -366,29 +353,14 @@ Status SplitToSequence::Compute(OpKernelContext* context) const {
|
|||
const Tensor& input = *context->Input<Tensor>(0);
|
||||
const Tensor* p_split_input = context->Input<Tensor>(1);
|
||||
|
||||
Status status;
|
||||
|
||||
if (input.IsDataType<float>())
|
||||
status = ComputeImpl<float>(*context, input, p_split_input);
|
||||
else if (input.IsDataType<double>())
|
||||
status = ComputeImpl<double>(*context, input, p_split_input);
|
||||
else if (input.IsDataType<int32_t>())
|
||||
status = ComputeImpl<int32_t>(*context, input, p_split_input);
|
||||
else if (input.IsDataType<int64_t>())
|
||||
status = ComputeImpl<int64_t>(*context, input, p_split_input);
|
||||
else if (input.IsDataTypeString())
|
||||
status = ComputeImpl<std::string>(*context, input, p_split_input);
|
||||
else
|
||||
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SplitToSequence operator does not support ", input.DataType(), " yet");
|
||||
|
||||
return status;
|
||||
return ComputeImpl(*context, input, p_split_input);
|
||||
}
|
||||
|
||||
Status SplitToSequence::PrepareForCompute(const TensorShape& input_shape, int64_t split_scalar, bool is_split_input_scalar,
|
||||
int64_t& num_outputs, int64_t& axis, int& before_dims,
|
||||
int& after_dims_including_split_axis, int& after_dims_excluding_split,
|
||||
bool& is_uneven_split, int& num_remaining_splits,
|
||||
std::vector<int64_t>& split_sizes) const {
|
||||
InlinedVector<int64_t>& split_sizes) const {
|
||||
auto input_dims = input_shape.GetDims();
|
||||
const auto num_dimensions = gsl::narrow_cast<int64_t>(input_shape.NumDimensions());
|
||||
axis = HandleNegativeAxis(axis_, num_dimensions); // handle negative and enforce axis is valid
|
||||
|
|
@ -416,7 +388,7 @@ Status SplitToSequence::PrepareForCompute(const TensorShape& input_shape, int64_
|
|||
// populate split_sizes with the same size for each output
|
||||
num_outputs = split_dim_size;
|
||||
// https://github.com/onnx/onnx/issues/2396
|
||||
split_sizes = std::vector<int64_t>(static_cast<size_t>(num_outputs), DEFAULT_LENGTH_EACH_OUTPUT_);
|
||||
split_sizes = InlinedVector<int64_t>(static_cast<size_t>(num_outputs), DEFAULT_LENGTH_EACH_OUTPUT_);
|
||||
} else {
|
||||
auto split_size_sum = std::accumulate(split_sizes.cbegin(), split_sizes.cend(), 0LL);
|
||||
if (split_size_sum != split_dim_size) {
|
||||
|
|
@ -453,7 +425,7 @@ static int64_t GetScalarSplitInput(const Tensor& tensor) {
|
|||
return retval;
|
||||
}
|
||||
|
||||
static void GetSplitSizesInput(const Tensor& tensor, std::vector<int64_t>& split_sizes) {
|
||||
static void GetSplitSizesInput(const Tensor& tensor, InlinedVector<int64_t>& split_sizes) {
|
||||
auto num_elems = tensor.Shape().Size();
|
||||
split_sizes.reserve(onnxruntime::narrow<size_t>(num_elems));
|
||||
if (tensor.IsDataType<int32_t>()) {
|
||||
|
|
@ -467,13 +439,8 @@ static void GetSplitSizesInput(const Tensor& tensor, std::vector<int64_t>& split
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& input,
|
||||
const Tensor* p_split_input) const {
|
||||
if (!utils::HasType<EnabledSplitToSequenceDataTypes, T>()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Data type is not supported in this build.");
|
||||
}
|
||||
|
||||
auto& input_shape = input.Shape();
|
||||
int64_t num_outputs = 0;
|
||||
int64_t axis = axis_;
|
||||
|
|
@ -484,7 +451,9 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu
|
|||
bool is_split_input_scalar = false;
|
||||
bool is_uneven_split = false;
|
||||
int num_remaining_splits = 0;
|
||||
std::vector<int64_t> split_sizes;
|
||||
InlinedVector<int64_t> split_sizes;
|
||||
const bool is_string_type = input.IsDataTypeString();
|
||||
const size_t element_size = (is_string_type) ? 0U : input.DataType()->Size();
|
||||
|
||||
// figure out split_scalar or split_sizes
|
||||
if (p_split_input) {
|
||||
|
|
@ -520,8 +489,8 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu
|
|||
|
||||
// copy dimensions so we can update the selected axis in place
|
||||
auto output_dimensions = input_shape.AsShapeVector();
|
||||
int64_t input_offset = 0;
|
||||
const T* input_data = input.Data<T>();
|
||||
SafeInt<size_t> input_offset = 0;
|
||||
const void* input_data = input.DataRaw();
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
// update size of dimension for axis we're splitting on while considering uneven split
|
||||
int split_size;
|
||||
|
|
@ -535,20 +504,50 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu
|
|||
AllocatorPtr alloc;
|
||||
ORT_RETURN_IF_ERROR(context.GetTempSpaceAllocator(&alloc));
|
||||
Tensor output_tensor(input.DataType(), onnxruntime::TensorShape(output_dimensions), alloc);
|
||||
T* output_data = output_tensor.MutableData<T>();
|
||||
void* output_data = output_tensor.MutableDataRaw();
|
||||
|
||||
::onnxruntime::math::CopyMatrix<T>(
|
||||
before_dims, // M
|
||||
split_size * after_dims_excluding_split, // N
|
||||
static_cast<const T*>(input_data + input_offset), // A
|
||||
after_dims_including_split_axis, // lda
|
||||
static_cast<T*>(output_data), // B
|
||||
split_size * after_dims_excluding_split, // ldb
|
||||
[](const T* src, T* dst, size_t count) {
|
||||
copy_data<T>(src, dst, count);
|
||||
});
|
||||
const auto M = before_dims;
|
||||
const auto* A = static_cast<const char*>(input_data) + static_cast<size_t>(input_offset * element_size);
|
||||
const auto lda = after_dims_including_split_axis;
|
||||
auto* B = output_data;
|
||||
|
||||
input_offset += static_cast<int64_t>(split_size) * after_dims_excluding_split; // offset by the N data we used in this iteration
|
||||
const auto N = split_size * after_dims_excluding_split;
|
||||
const auto ldb = N;
|
||||
|
||||
if (is_string_type) {
|
||||
const auto* src = reinterpret_cast<const std::string*>(A);
|
||||
auto* dst = reinterpret_cast<std::string*>(B);
|
||||
if (lda == N) {
|
||||
copy_data<std::string>(src, dst, static_cast<size_t>(M * N));
|
||||
} else {
|
||||
size_t lda_offset = 0;
|
||||
size_t ldb_offset = 0;
|
||||
for (size_t idx = 0; idx < static_cast<size_t>(M); ++idx,
|
||||
lda_offset += lda, ldb_offset += ldb) {
|
||||
copy_data<std::string>(src + lda_offset, dst + ldb_offset, static_cast<size_t>(N));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (lda == N) {
|
||||
// if the data is contiguous, we can just copy the data
|
||||
const size_t bytes_to_copy = static_cast<size_t>(N) * static_cast<size_t>(M) * element_size;
|
||||
memcpy(B, A, bytes_to_copy);
|
||||
} else {
|
||||
// otherwise we need to copy each row
|
||||
const size_t row_bytes = SafeInt<size_t>(N) * element_size;
|
||||
const auto lda_bytes_inc = SafeInt<size_t>(lda) * element_size;
|
||||
const auto ldb_bytes_inc = SafeInt<size_t>(ldb) * element_size;
|
||||
SafeInt<size_t> lda_bytes_offset = 0;
|
||||
SafeInt<size_t> ldb_bytes_offset = 0;
|
||||
for (size_t idx = 0; idx < static_cast<size_t>(M); ++idx,
|
||||
lda_bytes_offset += lda_bytes_inc, ldb_bytes_offset += ldb_bytes_inc) {
|
||||
memcpy(reinterpret_cast<char*>(B) + static_cast<size_t>(ldb_bytes_offset),
|
||||
reinterpret_cast<const char*>(A) + static_cast<size_t>(lda_bytes_offset), row_bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
input_offset += SafeInt<size_t>(split_size) * after_dims_excluding_split; // offset by the N data we used in this iteration
|
||||
|
||||
// if keep_dims = 0, reshape the tensor by dropping the dimension corresponding to 'axis'
|
||||
if (use_keep_dims && keepdims_ == 0) {
|
||||
|
|
|
|||
|
|
@ -60,13 +60,12 @@ class SplitToSequence final : public OpKernel {
|
|||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
Status ComputeImpl(OpKernelContext& context, const Tensor& input, const Tensor* p_split_input) const;
|
||||
Status PrepareForCompute(const TensorShape& input_shape, int64_t split_scalar, bool is_split_input_scalar,
|
||||
int64_t& num_outputs, int64_t& axis, int& before_dims,
|
||||
int& after_dims_including_split_axis, int& after_dims_excluding_split,
|
||||
bool& is_uneven_split, int& num_remaining_splits,
|
||||
std::vector<int64_t>& split_sizes) const;
|
||||
InlinedVector<int64_t>& split_sizes) const;
|
||||
int64_t axis_{};
|
||||
int64_t keepdims_{1};
|
||||
const int64_t DEFAULT_LENGTH_EACH_OUTPUT_ = 1;
|
||||
|
|
|
|||
|
|
@ -330,15 +330,26 @@ TEST(SequenceOpsTest, SequenceConstructPositive) {
|
|||
|
||||
// SplitToSequence
|
||||
template <typename T>
|
||||
static std::vector<T> GetConsequtiveVector(T start, int num) {
|
||||
static std::vector<T> GetConsecutiveVector(T start, size_t num) {
|
||||
std::vector<T> inputv(num);
|
||||
std::iota(inputv.begin(), inputv.end(), start);
|
||||
return inputv;
|
||||
}
|
||||
|
||||
// Explicit specialization of GetConsecutiveVector for MLFloat16.
// Returns `num` elements where element i is MLFloat16{start.ToFloat() + i}.
// The arithmetic is done in float and each sum is converted back to MLFloat16,
// whereas the primary template fills the vector with std::iota.
template <>
|
||||
std::vector<MLFloat16> GetConsecutiveVector<MLFloat16>(MLFloat16 start, size_t num) {
|
||||
std::vector<MLFloat16> inputv;
|
||||
inputv.reserve(num);
|
||||
float start_f = start.ToFloat();
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
inputv.push_back(MLFloat16{start_f + static_cast<float>(i)});
|
||||
}
|
||||
return inputv;
|
||||
}
|
||||
|
||||
TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitFloat) {
|
||||
OpTester test("SplitToSequence", 11);
|
||||
test.AddInput<float>("input", {4, 2}, GetConsequtiveVector<float>(1.f, 8));
|
||||
test.AddInput<float>("input", {4, 2}, GetConsecutiveVector<float>(1.f, 8));
|
||||
test.AddInput<int64_t>("split", {1, 2}, {2, 2});
|
||||
SeqTensors<float> output;
|
||||
output.AddTensor({2, 2}, {1.f, 2.f, 3.f, 4.f});
|
||||
|
|
@ -347,9 +358,31 @@ TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitFloat) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
// MLFloat16 counterpart of the float equal-split test: feeds a {4, 2} MLFloat16
// input holding eight consecutive values starting at MLFloat16::One, with a
// `split` input of {2, 2} and no explicit axis attribute, and expects a
// sequence of two {2, 2} tensors containing 1..4 and 5..8 respectively.
TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitMLFloat16) {
|
||||
OpTester test("SplitToSequence", 11);
|
||||
test.AddInput<MLFloat16>("input", {4, 2}, GetConsecutiveVector<MLFloat16>(MLFloat16::One, 8));
|
||||
test.AddInput<int64_t>("split", {1, 2}, {2, 2});
|
||||
SeqTensors<MLFloat16> output;
|
||||
|
||||
// Expected first output tensor: float literals converted element-wise to MLFloat16.
std::vector<MLFloat16> tensor_1;
|
||||
const auto data_1 = {1.f, 2.f, 3.f, 4.f};
|
||||
for (auto f : data_1)
|
||||
tensor_1.push_back(MLFloat16{f});
|
||||
|
||||
// Expected second output tensor, built the same way.
std::vector<MLFloat16> tensor_2;
|
||||
const auto data_2 = {5.f, 6.f, 7.f, 8.f};
|
||||
for (auto f : data_2)
|
||||
tensor_2.push_back(MLFloat16{f});
|
||||
|
||||
output.AddTensor({2, 2}, tensor_1);
|
||||
output.AddTensor({2, 2}, tensor_2);
|
||||
test.AddSeqOutput("S2", output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitLong) {
|
||||
OpTester test("SplitToSequence", 11);
|
||||
test.AddInput<int64_t>("input", {4, 2}, GetConsequtiveVector<int64_t>(1, 8));
|
||||
test.AddInput<int64_t>("input", {4, 2}, GetConsecutiveVector<int64_t>(1, 8));
|
||||
test.AddInput<int64_t>("split", {1, 2}, {2, 2});
|
||||
SeqTensors<int64_t> output;
|
||||
output.AddTensor({2, 2}, {1, 2, 3, 4});
|
||||
|
|
@ -360,7 +393,7 @@ TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitLong) {
|
|||
|
||||
TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitFloatScalarSplit) {
|
||||
OpTester test("SplitToSequence", 11);
|
||||
test.AddInput<float>("input", {4, 2}, GetConsequtiveVector<float>(1.f, 8));
|
||||
test.AddInput<float>("input", {4, 2}, GetConsecutiveVector<float>(1.f, 8));
|
||||
test.AddInput<int64_t>("split", {}, {2});
|
||||
SeqTensors<float> output;
|
||||
output.AddTensor({2, 2}, {1.f, 2.f, 3.f, 4.f});
|
||||
|
|
@ -371,7 +404,7 @@ TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitFloatScalarSplit) {
|
|||
|
||||
TEST(SequenceOpsTest, SplitToSequence_Axis0DefaultSplitFloatSetAxisExplicitly) {
|
||||
OpTester test("SplitToSequence", 11);
|
||||
test.AddInput<float>("input", {4, 2}, GetConsequtiveVector<float>(1.f, 8));
|
||||
test.AddInput<float>("input", {4, 2}, GetConsecutiveVector<float>(1.f, 8));
|
||||
int64_t axis = 0;
|
||||
test.AddAttribute("axis", axis);
|
||||
SeqTensors<float> output;
|
||||
|
|
@ -385,7 +418,7 @@ TEST(SequenceOpsTest, SplitToSequence_Axis0DefaultSplitFloatSetAxisExplicitly) {
|
|||
|
||||
TEST(SequenceOpsTest, SplitToSequence_PositiveAxisScalarSplit) {
|
||||
OpTester test("SplitToSequence", 11);
|
||||
test.AddInput<float>("input", {2, 2, 6}, GetConsequtiveVector<float>(1.f, 2 * 2 * 6));
|
||||
test.AddInput<float>("input", {2, 2, 6}, GetConsecutiveVector<float>(1.f, 2 * 2 * 6));
|
||||
int64_t axis = 2;
|
||||
test.AddAttribute("axis", axis);
|
||||
test.AddInput<int64_t>("split", {}, {2});
|
||||
|
|
@ -411,11 +444,11 @@ TEST(SequenceOpsTest, SplitToSequence_PositiveAxisScalarSplit) {
|
|||
|
||||
TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0UnevenSplitFloat) {
|
||||
OpTester test("SplitToSequence", 11);
|
||||
test.AddInput<float>("input", {5, 2}, GetConsequtiveVector<float>(1.f, 10));
|
||||
test.AddInput<float>("input", {5, 2}, GetConsecutiveVector<float>(1.f, 10));
|
||||
test.AddInput<int64_t>("split", {}, {2});
|
||||
SeqTensors<float> output;
|
||||
output.AddTensor({2, 2}, GetConsequtiveVector<float>(1.f, 4));
|
||||
output.AddTensor({2, 2}, GetConsequtiveVector<float>(5.f, 4));
|
||||
output.AddTensor({2, 2}, GetConsecutiveVector<float>(1.f, 4));
|
||||
output.AddTensor({2, 2}, GetConsecutiveVector<float>(5.f, 4));
|
||||
output.AddTensor({1, 2}, {9.f, 10.f});
|
||||
test.AddSeqOutput("S2", output);
|
||||
test.Run();
|
||||
|
|
@ -423,22 +456,22 @@ TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0UnevenSplitFloat) {
|
|||
|
||||
TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0UnevenSplitFloat2) {
|
||||
OpTester test("SplitToSequence", 11);
|
||||
test.AddInput<float>("input", {17, 2}, GetConsequtiveVector<float>(1.f, 34));
|
||||
test.AddInput<float>("input", {17, 2}, GetConsecutiveVector<float>(1.f, 34));
|
||||
test.AddInput<int64_t>("split", {}, {3});
|
||||
SeqTensors<float> output;
|
||||
output.AddTensor({3, 2}, GetConsequtiveVector<float>(1.f, 6));
|
||||
output.AddTensor({3, 2}, GetConsequtiveVector<float>(7.f, 6));
|
||||
output.AddTensor({3, 2}, GetConsequtiveVector<float>(13.f, 6));
|
||||
output.AddTensor({3, 2}, GetConsequtiveVector<float>(19.f, 6));
|
||||
output.AddTensor({3, 2}, GetConsequtiveVector<float>(25.f, 6));
|
||||
output.AddTensor({2, 2}, GetConsequtiveVector<float>(31.f, 4));
|
||||
output.AddTensor({3, 2}, GetConsecutiveVector<float>(1.f, 6));
|
||||
output.AddTensor({3, 2}, GetConsecutiveVector<float>(7.f, 6));
|
||||
output.AddTensor({3, 2}, GetConsecutiveVector<float>(13.f, 6));
|
||||
output.AddTensor({3, 2}, GetConsecutiveVector<float>(19.f, 6));
|
||||
output.AddTensor({3, 2}, GetConsecutiveVector<float>(25.f, 6));
|
||||
output.AddTensor({2, 2}, GetConsecutiveVector<float>(31.f, 4));
|
||||
test.AddSeqOutput("S2", output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(SequenceOpsTest, SplitToSequence_PositiveAxisUnevenSplit) {
|
||||
OpTester test("SplitToSequence", 11);
|
||||
test.AddInput<float>("input", {2, 5}, GetConsequtiveVector<float>(1.f, 10));
|
||||
test.AddInput<float>("input", {2, 5}, GetConsecutiveVector<float>(1.f, 10));
|
||||
test.AddInput<int64_t>("split", {}, {2});
|
||||
int64_t axis = 1;
|
||||
test.AddAttribute("axis", axis);
|
||||
|
|
@ -452,33 +485,33 @@ TEST(SequenceOpsTest, SplitToSequence_PositiveAxisUnevenSplit) {
|
|||
|
||||
TEST(SequenceOpsTest, SplitToSequence_Axis0DefaultSplitFloatSetAxisExplicitlyDontKeepDims3Dim) {
|
||||
OpTester test("SplitToSequence", 11);
|
||||
test.AddInput<float>("input", {2, 3, 4}, GetConsequtiveVector<float>(1.f, 2 * 3 * 4));
|
||||
test.AddInput<float>("input", {2, 3, 4}, GetConsecutiveVector<float>(1.f, 2 * 3 * 4));
|
||||
test.AddAttribute<int64_t>("keepdims", 0);
|
||||
int64_t axis = 0;
|
||||
test.AddAttribute("axis", axis);
|
||||
SeqTensors<float> output;
|
||||
output.AddTensor({3, 4}, GetConsequtiveVector<float>(1.f, 12));
|
||||
output.AddTensor({3, 4}, GetConsequtiveVector<float>(13.f, 12));
|
||||
output.AddTensor({3, 4}, GetConsecutiveVector<float>(1.f, 12));
|
||||
output.AddTensor({3, 4}, GetConsecutiveVector<float>(13.f, 12));
|
||||
test.AddSeqOutput("S2", output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(SequenceOpsTest, SplitToSequence_Axis0DefaultSplitFloatSetAxisExplicitlyDontKeepDims2Dim) {
|
||||
OpTester test("SplitToSequence", 11);
|
||||
test.AddInput<float>("input", {2, 3}, GetConsequtiveVector<float>(1.f, 2 * 3));
|
||||
test.AddInput<float>("input", {2, 3}, GetConsecutiveVector<float>(1.f, 2 * 3));
|
||||
test.AddAttribute<int64_t>("keepdims", 0);
|
||||
int64_t axis = 0;
|
||||
test.AddAttribute("axis", axis);
|
||||
SeqTensors<float> output;
|
||||
output.AddTensor({3}, GetConsequtiveVector<float>(1.f, 3));
|
||||
output.AddTensor({3}, GetConsequtiveVector<float>(4.f, 3));
|
||||
output.AddTensor({3}, GetConsecutiveVector<float>(1.f, 3));
|
||||
output.AddTensor({3}, GetConsecutiveVector<float>(4.f, 3));
|
||||
test.AddSeqOutput("S2", output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(SequenceOpsTest, SplitToSequence_PositiveAxisDontKeepDims) {
|
||||
OpTester test("SplitToSequence", 11);
|
||||
test.AddInput<float>("input", {2, 3, 4}, GetConsequtiveVector<float>(1.f, 2 * 3 * 4));
|
||||
test.AddInput<float>("input", {2, 3, 4}, GetConsecutiveVector<float>(1.f, 2 * 3 * 4));
|
||||
test.AddAttribute<int64_t>("keepdims", 0);
|
||||
int64_t axis = 2;
|
||||
test.AddAttribute("axis", axis);
|
||||
|
|
|
|||
Loading…
Reference in a new issue