diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 3ab3c0deca..6860f270c8 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -329,9 +329,9 @@ Do not modify directly.*
|Softsign|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)|
|SpaceToDepth|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)|
|||[1, 12]|**T** = tensor(double), tensor(float)|
-|Split|*in* input:**T**
*in* split:**T**
*out* outputs...:**T**
or
*in* input:**T**
*in* split:**tensor(int64)**
*out* outputs:**T**
or
*in* input:**T**
*out* outputs:**T**|13+|**T** = tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint8)|
-|||[11, 12]|**T** = tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint8)|
-|||[2, 10]|**T** = tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint8)|
+|Split|*in* input:**T**
*in* split:**T**
*out* outputs...:**T**
or
*in* input:**T**
*in* split:**tensor(int64)**
*out* outputs:**T**
or
*in* input:**T**
*out* outputs:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|SplitToSequence|*in* input:**T**
*in* split:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string)|
|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|||[6, 12]|**T** = tensor(double), tensor(float)|
diff --git a/onnxruntime/core/framework/copy.h b/onnxruntime/core/framework/copy.h
index d961aef6b1..0ba1f2a0a9 100644
--- a/onnxruntime/core/framework/copy.h
+++ b/onnxruntime/core/framework/copy.h
@@ -122,19 +122,23 @@ void StridedCopy(concurrency::ThreadPool* thread_pool,
!copy_shape.empty(),
"src and dst must have same shape and not be rank 0.");
- const std::size_t dims = copy_shape.size();
- // We will iterate over the output dimensions
- int64_t num_iterations = 1;
- for (std::size_t dim = 0; dim < dims; dim++) {
- num_iterations *= copy_shape[dim];
+ const int64_t total_num_elements_to_copy = copy_shape_in.Size();
+
+ ORT_ENFORCE(total_num_elements_to_copy >= 0, "copy shape must have non-negative size");
+
+ if (total_num_elements_to_copy == 0) {
+ // empty edge case
+ return;
}
- if (num_iterations <= 1) {
+ if (total_num_elements_to_copy == 1) {
// scalar edge case
dst[0] = src[0];
return;
}
+ const std::size_t dims = copy_shape.size();
+
// TODOs for when we have strided tensors:
// - Reorder dimensions so that we iterate along the smallest strides first
@@ -151,7 +155,7 @@ void StridedCopy(concurrency::ThreadPool* thread_pool,
std::ptrdiff_t contiguous_span_size = static_cast(dims == 2 ? copy_shape[1] : copy_shape[0]);
concurrency::ThreadPool::TryParallelFor(
- thread_pool, static_cast(num_iterations),
+ thread_pool, static_cast(total_num_elements_to_copy),
{static_cast(sizeof(T)), static_cast(sizeof(T)), 1.0F},
[src_stride, dst_stride, dst, src, contiguous_span_size](std::ptrdiff_t first, std::ptrdiff_t last) {
// get the current inner and outer index
@@ -196,7 +200,7 @@ void StridedCopy(concurrency::ThreadPool* thread_pool,
const TensorShapeVector& const_copy_shape = copy_shape;
concurrency::ThreadPool::TryParallelFor(
- thread_pool, static_cast(num_iterations),
+ thread_pool, static_cast(total_num_elements_to_copy),
{static_cast(sizeof(T)), static_cast(sizeof(T)), 1.0F},
[&const_copy_shape, &const_dst_strides, dst, src, &const_src_strides, dims](std::ptrdiff_t first,
std::ptrdiff_t last) {
@@ -234,6 +238,7 @@ inline bool StridedCopyIfEnabled(concurrency::ThreadPool* thread_pool,
const TensorShapeVector& dst_strides,
const TensorShape& copy_shape,
const Tensor& src,
+ std::ptrdiff_t src_offset,
const TensorShapeVector& src_strides) {
constexpr bool enabled = utils::HasTypeWithSameSize();
if constexpr (enabled) {
@@ -242,7 +247,7 @@ inline bool StridedCopyIfEnabled(concurrency::ThreadPool* thread_pool,
StridedCopy(thread_pool,
reinterpret_cast(dst.MutableDataRaw()) + dst_offset,
dst_strides, copy_shape,
- reinterpret_cast(src.DataRaw()),
+ reinterpret_cast(src.DataRaw()) + src_offset,
src_strides);
}
@@ -259,40 +264,41 @@ Status DispatchStridedCopy(concurrency::ThreadPool* thread_pool,
const TensorShapeVector& dst_strides,
const TensorShape& copy_shape,
const Tensor& src,
+ std::ptrdiff_t src_offset,
const TensorShapeVector& src_strides) {
ORT_ENFORCE(dst.DataType() == src.DataType(), "src and dst types must match");
bool supported = false;
if (src.IsDataTypeString()) {
- if (utils::HasType()) {
+ if constexpr (utils::HasType()) {
supported = true;
StridedCopy(thread_pool, dst.MutableData() + dst_offset, dst_strides, copy_shape,
- src.Data(), src_strides);
+ src.Data() + src_offset, src_strides);
}
} else {
const auto element_size = src.DataType()->Size();
switch (element_size) {
case sizeof(uint32_t):
supported = StridedCopyIfEnabled(thread_pool, dst, dst_offset, dst_strides,
- copy_shape, src, src_strides);
+ copy_shape, src, src_offset, src_strides);
break;
case sizeof(uint64_t):
supported = StridedCopyIfEnabled(thread_pool, dst, dst_offset, dst_strides,
- copy_shape, src, src_strides);
+ copy_shape, src, src_offset, src_strides);
break;
case sizeof(uint16_t):
supported = StridedCopyIfEnabled(thread_pool, dst, dst_offset, dst_strides,
- copy_shape, src, src_strides);
+ copy_shape, src, src_offset, src_strides);
break;
case sizeof(uint8_t):
static_assert(sizeof(bool) == sizeof(uint8_t), "Need to enable separate case for 'bool' on this platform.");
supported = StridedCopyIfEnabled(thread_pool, dst, dst_offset, dst_strides,
- copy_shape, src, src_strides);
+ copy_shape, src, src_offset, src_strides);
break;
// It's possible that bool is not 1 byte. static_assert above checks if we need to enable this on a platform.
- //case sizeof(bool):
+ // case sizeof(bool):
// supported = StridedCopyIfEnabled(thread_pool, dst, dst_offset, dst_strides,
- // copy_shape, src, src_strides);
+ // copy_shape, src, src_offset, src_strides);
// break;
default:
// leave 'supported' as false
diff --git a/onnxruntime/core/framework/data_transfer.cc b/onnxruntime/core/framework/data_transfer.cc
index 4489f6f7c9..7f4f522248 100644
--- a/onnxruntime/core/framework/data_transfer.cc
+++ b/onnxruntime/core/framework/data_transfer.cc
@@ -56,8 +56,10 @@ common::Status CPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int /
auto src_stride_vec = src.Strides();
onnxruntime::TensorShapeVector dst_stride{dst_stride_vec.begin(), dst_stride_vec.end()};
onnxruntime::TensorShapeVector src_stride{src_stride_vec.begin(), src_stride_vec.end()};
- return DispatchStridedCopy(nullptr, dst, dst.ByteOffset(), dst_stride, src.Shape(), src,
- src_stride);
+ return DispatchStridedCopy(nullptr,
+ dst, 0, dst_stride,
+ src.Shape(),
+ src, 0, src_stride);
} else {
#endif
// Copying only happens between two same size tensors.
diff --git a/onnxruntime/core/providers/cpu/tensor/concat.cc b/onnxruntime/core/providers/cpu/tensor/concat.cc
index 1ff51caad7..13bcc54100 100644
--- a/onnxruntime/core/providers/cpu/tensor/concat.cc
+++ b/onnxruntime/core/providers/cpu/tensor/concat.cc
@@ -266,10 +266,11 @@ Status ConcatBase::ComputeImpl(Prepare& p, OpKernelContext* ctx) const {
// parallel copy the data across
auto status = DispatchStridedCopy(ctx->GetOperatorThreadPool(),
*p.output_tensor,
- onnxruntime::narrow(initial_output_offset),
+ onnxruntime::narrow(initial_output_offset),
output_strides_for_copy,
prep.tensor->Shape(),
*prep.tensor,
+ 0, // src_offset
StridesForTensor(*prep.tensor));
ORT_RETURN_IF_ERROR(status);
diff --git a/onnxruntime/core/providers/cpu/tensor/split.cc b/onnxruntime/core/providers/cpu/tensor/split.cc
index ca662c1999..5b73e8e68c 100644
--- a/onnxruntime/core/providers/cpu/tensor/split.cc
+++ b/onnxruntime/core/providers/cpu/tensor/split.cc
@@ -1,12 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include
#include "core/providers/cpu/tensor/split.h"
-#include "core/common/gsl.h"
-
#include "core/common/narrow.h"
+#include "core/common/gsl.h"
+#include "core/common/safeint.h"
+#include "core/framework/copy.h"
+#include "core/framework/element_type_lists.h"
#include "core/framework/op_kernel_type_control_utils.h"
#include "core/providers/common.h"
#include "core/providers/op_kernel_type_control.h"
@@ -16,9 +17,9 @@
namespace onnxruntime {
namespace op_kernel_type_control {
-ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
+ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Split, Input, 0,
- float, int8_t, int32_t, int64_t, uint8_t, std::string);
+ element_type_lists::All);
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Split, Input, 0,
int32_t, int64_t);
@@ -93,63 +94,25 @@ Status SplitBase::PrepareForCompute(const TensorShape& input_shape, int num_outp
Status Split::Compute(OpKernelContext* context) const {
const Tensor& input = *context->Input(0);
-
- Status status;
-
- // Note: The non-string implementations can probably be based on data type size.
- if (input.IsDataType())
- status = ComputeImpl(*context, input);
- else if (input.IsDataType())
- status = ComputeImpl(*context, input);
- else if (input.IsDataType())
- status = ComputeImpl(*context, input);
- else if (input.IsDataType())
- status = ComputeImpl(*context, input);
- else if (input.IsDataType())
- status = ComputeImpl(*context, input);
- else if (input.IsDataTypeString())
- status = ComputeImpl(*context, input);
- else
- ORT_THROW("Split operator does not support ", input.DataType(), " yet");
-
- return status;
-}
-
-template
-inline void copy_data(const T* src, T* dst, size_t count) {
- memcpy(dst, src, count * sizeof(T));
-}
-
-template <>
-inline void copy_data(const std::string* src, std::string* dst, size_t count) {
- const std::string* end = src + count;
- std::copy(src, end, dst);
-}
-
-template
-Status Split::ComputeImpl(OpKernelContext& context, const Tensor& input) const {
- if (!utils::HasType()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Data type is not supported in this build.");
- }
-
auto& input_shape = input.Shape();
- auto num_outputs = context.OutputCount();
+ auto num_outputs = context->OutputCount();
int64_t axis = axis_;
int before_dims = 0;
int after_dims_including_split_axis = 0;
int after_dims_excluding_split = 0;
std::vector split_sizes;
- const Tensor* split_tensor = context.Input(1);
+ const Tensor* split_tensor = context->Input(1);
if (split_tensor != nullptr) {
- //override the attribute value with the input value for split
- ORT_ENFORCE(split_tensor->Shape().NumDimensions() == 1, "An split tensor must be a vector tensor.");
+ // override the attribute value with the input value for split
+ ORT_ENFORCE(split_tensor->Shape().NumDimensions() == 1, "The split tensor must be a vector tensor.");
auto nDims = static_cast(split_tensor->Shape()[0]);
const auto* data = split_tensor->Data();
split_sizes.assign(data, data + nDims);
} else {
split_sizes.assign(split_sizes_.begin(), split_sizes_.end());
}
+
ORT_RETURN_IF_ERROR(PrepareForCompute(input_shape,
num_outputs,
axis,
@@ -158,32 +121,27 @@ Status Split::ComputeImpl(OpKernelContext& context, const Tensor& input) const {
after_dims_excluding_split,
split_sizes));
+ const auto input_strides = StridesForTensor(input);
+
// copy dimensions so we can update the selected axis in place
auto output_dimensions = input_shape.AsShapeVector();
- int64_t input_offset = 0;
- const T* input_data = input.Data();
+ SafeInt input_offset = 0;
for (int i = 0; i < num_outputs; ++i) {
// update size of dimension for axis we're splitting on
auto split_size = narrow(split_sizes[i]);
- output_dimensions[onnxruntime::narrow(axis)] = split_size;
+ output_dimensions[narrow(axis)] = split_size;
- Tensor* output = context.Output(i, TensorShape{output_dimensions});
- T* output_data = output->MutableData();
+ Tensor* output = context->Output(i, TensorShape{output_dimensions});
+ const auto output_strides = StridesForTensor(*output);
- ::onnxruntime::math::CopyMatrix(
- before_dims, // M
- split_size * after_dims_excluding_split, // N
- static_cast(input_data + input_offset), // A
- after_dims_including_split_axis, // lda
- static_cast(output_data), // B
- split_size * after_dims_excluding_split, // ldb
- [](const T* src, T* dst, size_t count) {
- copy_data(src, dst, count);
- });
+ ORT_RETURN_IF_ERROR(DispatchStridedCopy(context->GetOperatorThreadPool(),
+ *output, /* dst_offset */ 0, output_strides,
+ output->Shape(),
+ input, input_offset, input_strides));
- input_offset += static_cast(split_size) * after_dims_excluding_split; // offset by the N data we used in this iteration
+ input_offset += SafeInt(split_size) * after_dims_excluding_split; // offset by the data we used in this iteration
}
return Status::OK();
diff --git a/onnxruntime/core/providers/cpu/tensor/split.h b/onnxruntime/core/providers/cpu/tensor/split.h
index e455cb9157..389e1f5af8 100644
--- a/onnxruntime/core/providers/cpu/tensor/split.h
+++ b/onnxruntime/core/providers/cpu/tensor/split.h
@@ -46,10 +46,6 @@ class Split final : public OpKernel, public SplitBase {
Split(const OpKernelInfo& info) : OpKernel(info), SplitBase(info) {}
Status Compute(OpKernelContext* context) const override;
-
- private:
- template
- Status ComputeImpl(OpKernelContext& context, const Tensor& input) const;
};
} // namespace onnxruntime
diff --git a/onnxruntime/test/framework/copy_test.cc b/onnxruntime/test/framework/copy_test.cc
index 009895657b..1fc670f8b9 100644
--- a/onnxruntime/test/framework/copy_test.cc
+++ b/onnxruntime/test/framework/copy_test.cc
@@ -144,17 +144,6 @@ TEST_F(CopyTest, CoalesceTensorsTest) {
ASSERT_THAT(strides_b, testing::ElementsAre(1));
ASSERT_THAT(shape, testing::ElementsAre(15));
}
- {
- TensorShapeVector strides_a{3, 3, 3, 1};
- TensorShapeVector strides_b{3, 3, 3, 1};
- TensorShapeVector shape{1, 5, 1, 3};
-
- CoalesceDimensions({strides_a, strides_b}, shape);
-
- ASSERT_THAT(strides_a, testing::ElementsAre(1));
- ASSERT_THAT(strides_b, testing::ElementsAre(1));
- ASSERT_THAT(shape, testing::ElementsAre(15));
- }
{
TensorShapeVector strides_a{320, 1};
TensorShapeVector strides_b{320, 1};
@@ -188,17 +177,6 @@ TEST_F(CopyTest, CoalesceTensorsTest) {
ASSERT_THAT(strides_b, testing::ElementsAre(6, 1));
ASSERT_THAT(shape, testing::ElementsAre(5, 3));
}
- {
- TensorShapeVector strides_a{3, 1};
- TensorShapeVector strides_b{6, 1};
- TensorShapeVector shape{5, 3};
-
- CoalesceDimensions({strides_a, strides_b}, shape);
-
- ASSERT_THAT(strides_a, testing::ElementsAre(3, 1));
- ASSERT_THAT(strides_b, testing::ElementsAre(6, 1));
- ASSERT_THAT(shape, testing::ElementsAre(5, 3));
- }
{
TensorShapeVector strides_a{4, 1};
TensorShapeVector strides_b{1, 1};
diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
index c3b5b25ee7..bbc1f933b3 100644
--- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
@@ -2,11 +2,14 @@
// Licensed under the MIT License.
#include "gtest/gtest.h"
+#include "core/framework/to_tensor_proto_element_type.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
+namespace {
+
template
using ShapeAndData = std::pair, const std::vector>;
@@ -15,15 +18,27 @@ using ShapeAndStringData = ShapeAndData;
using ExpectResult = OpTester::ExpectResult;
template
-void RunTest(int64_t axis, const std::vector split_sizes, const ShapeAndData& input,
+void RunTest(int64_t axis, const std::vector& split_sizes, const ShapeAndData& input,
const std::vector>& outputs, bool is_tensorrt_supported = true,
bool expect_failure = false, bool split_as_input = false,
bool is_initializer = true, const std::string& err_msg = {}, bool skip_split_if_empty = true) {
int opset_version = split_as_input ? 13 : 7;
OpTester test("Split", opset_version, onnxruntime::kOnnxDomain);
+ constexpr bool is_bool_data = std::is_same_v;
+ [[maybe_unused]] auto bool_vector_to_array = [](const std::vector& v) -> std::unique_ptr {
+ auto a = std::make_unique(v.size());
+ std::copy(v.begin(), v.end(), a.get());
+ return a;
+ };
+
test.AddAttribute("axis", axis);
- test.AddInput("input", input.first, input.second);
+ if constexpr (is_bool_data) {
+ auto input_array = bool_vector_to_array(input.second);
+ test.AddInput("input", input.first, input_array.get(), input.second.size());
+ } else {
+ test.AddInput("input", input.first, input.second);
+ }
if (!split_sizes.empty()) {
if (split_as_input) {
test.AddInput("split", {static_cast(split_sizes.size())}, split_sizes, is_initializer);
@@ -38,9 +53,14 @@ void RunTest(int64_t axis, const std::vector split_sizes, const ShapeAn
for (auto& output : outputs) {
auto& shape = output.first;
auto& data = output.second;
- std::ostringstream oss;
- oss << "output" << i++;
- test.AddOutput(oss.str().c_str(), shape, data);
+ const auto output_name = MakeString("output", i++);
+
+ if constexpr (is_bool_data) {
+ auto data_array = bool_vector_to_array(data);
+ test.AddOutput(output_name.c_str(), shape, data_array.get(), data.size());
+ } else {
+ test.AddOutput(output_name.c_str(), shape, data);
+ }
}
std::unordered_set excluded_providers;
if (!is_tensorrt_supported) {
@@ -49,83 +69,73 @@ void RunTest(int64_t axis, const std::vector split_sizes, const ShapeAn
test.Run(expect_failure ? ExpectResult::kExpectFailure : ExpectResult::kExpectSuccess, err_msg, excluded_providers);
}
-TEST(SplitOperatorTest, Axis0EqualSplitFloat) {
- constexpr int64_t axis = 0;
- std::vector outputs;
+template
+[[maybe_unused]] constexpr bool dependent_false_v = false;
- // input shape and data
- ShapeAndFloatData input = {{4, 2}, // shape
- {1.f, 2.f,
- 3.f, 4.f,
- 5.f, 6.f,
- 7.f, 8.f}};
-
- outputs.push_back({{2, 2},
- {1.f, 2.f,
- 3.f, 4.f}});
-
- outputs.push_back({{2, 2},
- {5.f, 6.f,
- 7.f, 8.f}});
-
- RunTest(axis, {}, input, outputs, false); //TensorRT parser: Assertion failed: axis != BATCH_DIM
+template
+constexpr T ValueFromIdx(size_t idx) {
+ if constexpr (std::is_same_v) {
+ const char c = gsl::narrow_cast('a' + idx);
+ return std::string(1, c);
+ } else if constexpr (std::is_same_v) {
+ return (idx & 1) == 1;
+ } else if constexpr (std::is_integral_v || std::is_floating_point_v) {
+ return gsl::narrow_cast(idx);
+ } else if constexpr (std::is_same_v || std::is_same_v) {
+ return T{static_cast(idx)};
+ } else {
+ static_assert(dependent_false_v, "unsupported type");
+ }
}
-template ::value, T>::type>
-static void SplitTestInt() {
+template
+void SplitTestAxis0EqualSplit(bool use_opset_13 = false) {
+ SCOPED_TRACE(onnxruntime::MakeString("data type: ", utils::ToTensorProtoElementType()));
+
constexpr int64_t axis = 0;
std::vector> outputs;
+ const auto V = ValueFromIdx;
+
// input shape and data
ShapeAndData input = {{4, 2}, // shape
- {1, 2,
- 3, 4,
- 5, 6,
- 7, 8}};
+ {V(1), V(2),
+ V(3), V(4),
+ V(5), V(6),
+ V(7), V(8)}};
outputs.push_back({{2, 2},
- {1, 2,
- 3, 4}});
+ {V(1), V(2),
+ V(3), V(4)}});
outputs.push_back({{2, 2},
- {5, 6,
- 7, 8}});
+ {V(5), V(6),
+ V(7), V(8)}});
- RunTest(axis, {}, input, outputs, false); //TensorRT parser: Assertion failed: axis != BATCH_DIM
+ RunTest(axis, {}, input, outputs,
+ // TensorRT parser: Assertion failed: axis != BATCH_DIM
+ false, // is_tensorrt_supported
+ false, // expect_failure
+ use_opset_13); // split_as_input
}
-TEST(SplitOperatorTest, Axis0EqualSplitInt8) {
- SplitTestInt();
-}
+} // namespace
-TEST(SplitOperatorTest, Axis0EqualSplitInt32) {
- SplitTestInt();
-}
-
-TEST(SplitOperatorTest, Axis0EqualSplitInt64) {
- SplitTestInt();
-}
-
-TEST(SplitOperatorTest, Axis0EqualSplitString) {
- constexpr int64_t axis = 0;
- std::vector outputs;
-
- // input shape and data
- ShapeAndStringData input = {{4, 2}, // shape
- {"a", "b",
- "c", "d",
- "e", "f",
- "g", "h"}};
-
- outputs.push_back({{2, 2},
- {"a", "b",
- "c", "d"}});
-
- outputs.push_back({{2, 2},
- {"e", "f",
- "g", "h"}});
-
- RunTest(axis, {}, input, outputs, false); //TensorRT parser: Assertion failed: axis != BATCH_DIM
+TEST(SplitOperatorTest, Axis0EqualSplit) {
+ SplitTestAxis0EqualSplit();
+ SplitTestAxis0EqualSplit();
+ SplitTestAxis0EqualSplit();
+ SplitTestAxis0EqualSplit(true); // BFloat16 added in opset 13
+ SplitTestAxis0EqualSplit();
+ SplitTestAxis0EqualSplit();
+ SplitTestAxis0EqualSplit();
+ SplitTestAxis0EqualSplit();
+ SplitTestAxis0EqualSplit();
+ SplitTestAxis0EqualSplit();
+ SplitTestAxis0EqualSplit();
+ SplitTestAxis0EqualSplit();
+ SplitTestAxis0EqualSplit();
+ SplitTestAxis0EqualSplit();
}
TEST(SplitOperatorTest, Axis0UnequalSplitFloat) {