Add support for other data types to Split CPU kernel. (#13900)

Split only copies data, so we can add support for all data types without much binary size impact by using implementations that are selected by data type size rather than by the data type itself. The DispatchStridedCopy() function used here does this.
This commit is contained in:
Edward Chen 2022-12-12 09:29:15 -08:00 committed by GitHub
parent 2cb12caf93
commit 8cfbc4fe91
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 130 additions and 179 deletions

View file

@ -329,9 +329,9 @@ Do not modify directly.*
|Softsign|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
|SpaceToDepth|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float)|
|||[1, 12]|**T** = tensor(double), tensor(float)|
|Split|*in* input:**T**<br> *in* split:**T**<br> *out* outputs...:**T**<br><br>or<br><br>*in* input:**T**<br> *in* split:**tensor(int64)**<br> *out* outputs:**T**<br><br>or<br><br>*in* input:**T**<br> *out* outputs:**T**|13+|**T** = tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint8)|
|||[11, 12]|**T** = tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint8)|
|||[2, 10]|**T** = tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint8)|
|Split|*in* input:**T**<br> *in* split:**T**<br> *out* outputs...:**T**<br><br>or<br><br>*in* input:**T**<br> *in* split:**tensor(int64)**<br> *out* outputs:**T**<br><br>or<br><br>*in* input:**T**<br> *out* outputs:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|SplitToSequence|*in* input:**T**<br> *in* split:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string)|
|Sqrt|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|||[6, 12]|**T** = tensor(double), tensor(float)|

View file

@ -122,19 +122,23 @@ void StridedCopy(concurrency::ThreadPool* thread_pool,
!copy_shape.empty(),
"src and dst must have same shape and not be rank 0.");
const std::size_t dims = copy_shape.size();
// We will iterate over the output dimensions
int64_t num_iterations = 1;
for (std::size_t dim = 0; dim < dims; dim++) {
num_iterations *= copy_shape[dim];
const int64_t total_num_elements_to_copy = copy_shape_in.Size();
ORT_ENFORCE(total_num_elements_to_copy >= 0, "copy shape must have non-negative size");
if (total_num_elements_to_copy == 0) {
// empty edge case
return;
}
if (num_iterations <= 1) {
if (total_num_elements_to_copy == 1) {
// scalar edge case
dst[0] = src[0];
return;
}
const std::size_t dims = copy_shape.size();
// TODOs for when we have strided tensors:
// - Reorder dimensions so that we iterate along the smallest strides first
@ -151,7 +155,7 @@ void StridedCopy(concurrency::ThreadPool* thread_pool,
std::ptrdiff_t contiguous_span_size = static_cast<std::ptrdiff_t>(dims == 2 ? copy_shape[1] : copy_shape[0]);
concurrency::ThreadPool::TryParallelFor(
thread_pool, static_cast<std::ptrdiff_t>(num_iterations),
thread_pool, static_cast<std::ptrdiff_t>(total_num_elements_to_copy),
{static_cast<float>(sizeof(T)), static_cast<float>(sizeof(T)), 1.0F},
[src_stride, dst_stride, dst, src, contiguous_span_size](std::ptrdiff_t first, std::ptrdiff_t last) {
// get the current inner and outer index
@ -196,7 +200,7 @@ void StridedCopy(concurrency::ThreadPool* thread_pool,
const TensorShapeVector& const_copy_shape = copy_shape;
concurrency::ThreadPool::TryParallelFor(
thread_pool, static_cast<std::ptrdiff_t>(num_iterations),
thread_pool, static_cast<std::ptrdiff_t>(total_num_elements_to_copy),
{static_cast<float>(sizeof(T)), static_cast<float>(sizeof(T)), 1.0F},
[&const_copy_shape, &const_dst_strides, dst, src, &const_src_strides, dims](std::ptrdiff_t first,
std::ptrdiff_t last) {
@ -234,6 +238,7 @@ inline bool StridedCopyIfEnabled(concurrency::ThreadPool* thread_pool,
const TensorShapeVector& dst_strides,
const TensorShape& copy_shape,
const Tensor& src,
std::ptrdiff_t src_offset,
const TensorShapeVector& src_strides) {
constexpr bool enabled = utils::HasTypeWithSameSize<EnabledTypes, T>();
if constexpr (enabled) {
@ -242,7 +247,7 @@ inline bool StridedCopyIfEnabled(concurrency::ThreadPool* thread_pool,
StridedCopy<T>(thread_pool,
reinterpret_cast<T*>(dst.MutableDataRaw()) + dst_offset,
dst_strides, copy_shape,
reinterpret_cast<const T*>(src.DataRaw()),
reinterpret_cast<const T*>(src.DataRaw()) + src_offset,
src_strides);
}
@ -259,40 +264,41 @@ Status DispatchStridedCopy(concurrency::ThreadPool* thread_pool,
const TensorShapeVector& dst_strides,
const TensorShape& copy_shape,
const Tensor& src,
std::ptrdiff_t src_offset,
const TensorShapeVector& src_strides) {
ORT_ENFORCE(dst.DataType() == src.DataType(), "src and dst types must match");
bool supported = false;
if (src.IsDataTypeString()) {
if (utils::HasType<EnabledDataTypes, std::string>()) {
if constexpr (utils::HasType<EnabledDataTypes, std::string>()) {
supported = true;
StridedCopy(thread_pool, dst.MutableData<std::string>() + dst_offset, dst_strides, copy_shape,
src.Data<std::string>(), src_strides);
src.Data<std::string>() + src_offset, src_strides);
}
} else {
const auto element_size = src.DataType()->Size();
switch (element_size) {
case sizeof(uint32_t):
supported = StridedCopyIfEnabled<EnabledDataTypes, uint32_t>(thread_pool, dst, dst_offset, dst_strides,
copy_shape, src, src_strides);
copy_shape, src, src_offset, src_strides);
break;
case sizeof(uint64_t):
supported = StridedCopyIfEnabled<EnabledDataTypes, uint64_t>(thread_pool, dst, dst_offset, dst_strides,
copy_shape, src, src_strides);
copy_shape, src, src_offset, src_strides);
break;
case sizeof(uint16_t):
supported = StridedCopyIfEnabled<EnabledDataTypes, uint16_t>(thread_pool, dst, dst_offset, dst_strides,
copy_shape, src, src_strides);
copy_shape, src, src_offset, src_strides);
break;
case sizeof(uint8_t):
static_assert(sizeof(bool) == sizeof(uint8_t), "Need to enable separate case for 'bool' on this platform.");
supported = StridedCopyIfEnabled<EnabledDataTypes, uint8_t>(thread_pool, dst, dst_offset, dst_strides,
copy_shape, src, src_strides);
copy_shape, src, src_offset, src_strides);
break;
// It's possible that bool is not 1 byte. static_assert above checks if we need to enable this on a platform.
//case sizeof(bool):
// case sizeof(bool):
// supported = StridedCopyIfEnabled<EnabledDataTypes, bool>(thread_pool, dst, dst_offset, dst_strides,
// copy_shape, src, src_strides);
// copy_shape, src, src_offset, src_strides);
// break;
default:
// leave 'supported' as false

View file

@ -56,8 +56,10 @@ common::Status CPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int /
auto src_stride_vec = src.Strides();
onnxruntime::TensorShapeVector dst_stride{dst_stride_vec.begin(), dst_stride_vec.end()};
onnxruntime::TensorShapeVector src_stride{src_stride_vec.begin(), src_stride_vec.end()};
return DispatchStridedCopy<element_type_lists::All>(nullptr, dst, dst.ByteOffset(), dst_stride, src.Shape(), src,
src_stride);
return DispatchStridedCopy<element_type_lists::All>(nullptr,
dst, 0, dst_stride,
src.Shape(),
src, 0, src_stride);
} else {
#endif
// Copying only happens between two same size tensors.

View file

@ -266,10 +266,11 @@ Status ConcatBase::ComputeImpl(Prepare& p, OpKernelContext* ctx) const {
// parallel copy the data across
auto status = DispatchStridedCopy<EnabledDataTypes>(ctx->GetOperatorThreadPool(),
*p.output_tensor,
onnxruntime::narrow<size_t>(initial_output_offset),
onnxruntime::narrow<ptrdiff_t>(initial_output_offset),
output_strides_for_copy,
prep.tensor->Shape(),
*prep.tensor,
0, // src_offset
StridesForTensor(*prep.tensor));
ORT_RETURN_IF_ERROR(status);

View file

@ -1,12 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <core/common/safeint.h>
#include "core/providers/cpu/tensor/split.h"
#include "core/common/gsl.h"
#include "core/common/narrow.h"
#include "core/common/gsl.h"
#include "core/common/safeint.h"
#include "core/framework/copy.h"
#include "core/framework/element_type_lists.h"
#include "core/framework/op_kernel_type_control_utils.h"
#include "core/providers/common.h"
#include "core/providers/op_kernel_type_control.h"
@ -16,9 +17,9 @@
namespace onnxruntime {
namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Split, Input, 0,
float, int8_t, int32_t, int64_t, uint8_t, std::string);
element_type_lists::All);
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Split, Input, 0,
int32_t, int64_t);
@ -93,63 +94,25 @@ Status SplitBase::PrepareForCompute(const TensorShape& input_shape, int num_outp
Status Split::Compute(OpKernelContext* context) const {
const Tensor& input = *context->Input<Tensor>(0);
Status status;
// Note: The non-string implementations can probably be based on data type size.
if (input.IsDataType<float>())
status = ComputeImpl<float>(*context, input);
else if (input.IsDataType<int32_t>())
status = ComputeImpl<int32_t>(*context, input);
else if (input.IsDataType<int64_t>())
status = ComputeImpl<int64_t>(*context, input);
else if (input.IsDataType<uint8_t>())
status = ComputeImpl<uint8_t>(*context, input);
else if (input.IsDataType<int8_t>())
status = ComputeImpl<int8_t>(*context, input);
else if (input.IsDataTypeString())
status = ComputeImpl<std::string>(*context, input);
else
ORT_THROW("Split operator does not support ", input.DataType(), " yet");
return status;
}
// Copies `count` elements from `src` into `dst` as one raw byte copy.
template <typename T>
inline void copy_data(const T* src, T* dst, size_t count) {
  const size_t num_bytes = sizeof(T) * count;
  memcpy(dst, src, num_bytes);
}

// std::string elements are copy-assigned one at a time instead of being copied as raw bytes.
template <>
inline void copy_data<std::string>(const std::string* src, std::string* dst, size_t count) {
  std::copy_n(src, count, dst);
}
template <typename T>
Status Split::ComputeImpl(OpKernelContext& context, const Tensor& input) const {
if (!utils::HasType<EnabledSplitDataTypes, T>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Data type is not supported in this build.");
}
auto& input_shape = input.Shape();
auto num_outputs = context.OutputCount();
auto num_outputs = context->OutputCount();
int64_t axis = axis_;
int before_dims = 0;
int after_dims_including_split_axis = 0;
int after_dims_excluding_split = 0;
std::vector<int64_t> split_sizes;
const Tensor* split_tensor = context.Input<Tensor>(1);
const Tensor* split_tensor = context->Input<Tensor>(1);
if (split_tensor != nullptr) {
//override the attribute value with the input value for split
ORT_ENFORCE(split_tensor->Shape().NumDimensions() == 1, "An split tensor must be a vector tensor.");
// override the attribute value with the input value for split
ORT_ENFORCE(split_tensor->Shape().NumDimensions() == 1, "The split tensor must be a vector tensor.");
auto nDims = static_cast<size_t>(split_tensor->Shape()[0]);
const auto* data = split_tensor->Data<int64_t>();
split_sizes.assign(data, data + nDims);
} else {
split_sizes.assign(split_sizes_.begin(), split_sizes_.end());
}
ORT_RETURN_IF_ERROR(PrepareForCompute(input_shape,
num_outputs,
axis,
@ -158,32 +121,27 @@ Status Split::ComputeImpl(OpKernelContext& context, const Tensor& input) const {
after_dims_excluding_split,
split_sizes));
const auto input_strides = StridesForTensor(input);
// copy dimensions so we can update the selected axis in place
auto output_dimensions = input_shape.AsShapeVector();
int64_t input_offset = 0;
const T* input_data = input.Data<T>();
SafeInt<ptrdiff_t> input_offset = 0;
for (int i = 0; i < num_outputs; ++i) {
// update size of dimension for axis we're splitting on
auto split_size = narrow<int>(split_sizes[i]);
output_dimensions[onnxruntime::narrow<size_t>(axis)] = split_size;
output_dimensions[narrow<size_t>(axis)] = split_size;
Tensor* output = context.Output(i, TensorShape{output_dimensions});
T* output_data = output->MutableData<T>();
Tensor* output = context->Output(i, TensorShape{output_dimensions});
const auto output_strides = StridesForTensor(*output);
::onnxruntime::math::CopyMatrix<T>(
before_dims, // M
split_size * after_dims_excluding_split, // N
static_cast<const T*>(input_data + input_offset), // A
after_dims_including_split_axis, // lda
static_cast<T*>(output_data), // B
split_size * after_dims_excluding_split, // ldb
[](const T* src, T* dst, size_t count) {
copy_data<T>(src, dst, count);
});
ORT_RETURN_IF_ERROR(DispatchStridedCopy<EnabledSplitDataTypes>(context->GetOperatorThreadPool(),
*output, /* dst_offset */ 0, output_strides,
output->Shape(),
input, input_offset, input_strides));
input_offset += static_cast<int64_t>(split_size) * after_dims_excluding_split; // offset by the N data we used in this iteration
input_offset += SafeInt<ptrdiff_t>(split_size) * after_dims_excluding_split; // offset by the data we used in this iteration
}
return Status::OK();

View file

@ -46,10 +46,6 @@ class Split final : public OpKernel, public SplitBase {
Split(const OpKernelInfo& info) : OpKernel(info), SplitBase(info) {}
Status Compute(OpKernelContext* context) const override;
private:
template <typename T>
Status ComputeImpl(OpKernelContext& context, const Tensor& input) const;
};
} // namespace onnxruntime

View file

@ -144,17 +144,6 @@ TEST_F(CopyTest, CoalesceTensorsTest) {
ASSERT_THAT(strides_b, testing::ElementsAre(1));
ASSERT_THAT(shape, testing::ElementsAre(15));
}
{
TensorShapeVector strides_a{3, 3, 3, 1};
TensorShapeVector strides_b{3, 3, 3, 1};
TensorShapeVector shape{1, 5, 1, 3};
CoalesceDimensions({strides_a, strides_b}, shape);
ASSERT_THAT(strides_a, testing::ElementsAre(1));
ASSERT_THAT(strides_b, testing::ElementsAre(1));
ASSERT_THAT(shape, testing::ElementsAre(15));
}
{
TensorShapeVector strides_a{320, 1};
TensorShapeVector strides_b{320, 1};
@ -188,17 +177,6 @@ TEST_F(CopyTest, CoalesceTensorsTest) {
ASSERT_THAT(strides_b, testing::ElementsAre(6, 1));
ASSERT_THAT(shape, testing::ElementsAre(5, 3));
}
{
TensorShapeVector strides_a{3, 1};
TensorShapeVector strides_b{6, 1};
TensorShapeVector shape{5, 3};
CoalesceDimensions({strides_a, strides_b}, shape);
ASSERT_THAT(strides_a, testing::ElementsAre(3, 1));
ASSERT_THAT(strides_b, testing::ElementsAre(6, 1));
ASSERT_THAT(shape, testing::ElementsAre(5, 3));
}
{
TensorShapeVector strides_a{4, 1};
TensorShapeVector strides_b{1, 1};

View file

@ -2,11 +2,14 @@
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "core/framework/to_tensor_proto_element_type.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
namespace {
template <class T>
using ShapeAndData = std::pair<const std::vector<int64_t>, const std::vector<T>>;
@ -15,15 +18,27 @@ using ShapeAndStringData = ShapeAndData<std::string>;
using ExpectResult = OpTester::ExpectResult;
template <typename T>
void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAndData<T>& input,
void RunTest(int64_t axis, const std::vector<int64_t>& split_sizes, const ShapeAndData<T>& input,
const std::vector<ShapeAndData<T>>& outputs, bool is_tensorrt_supported = true,
bool expect_failure = false, bool split_as_input = false,
bool is_initializer = true, const std::string& err_msg = {}, bool skip_split_if_empty = true) {
int opset_version = split_as_input ? 13 : 7;
OpTester test("Split", opset_version, onnxruntime::kOnnxDomain);
constexpr bool is_bool_data = std::is_same_v<T, bool>;
[[maybe_unused]] auto bool_vector_to_array = [](const std::vector<bool>& v) -> std::unique_ptr<bool[]> {
auto a = std::make_unique<bool[]>(v.size());
std::copy(v.begin(), v.end(), a.get());
return a;
};
test.AddAttribute("axis", axis);
test.AddInput<T>("input", input.first, input.second);
if constexpr (is_bool_data) {
auto input_array = bool_vector_to_array(input.second);
test.AddInput<T>("input", input.first, input_array.get(), input.second.size());
} else {
test.AddInput<T>("input", input.first, input.second);
}
if (!split_sizes.empty()) {
if (split_as_input) {
test.AddInput<int64_t>("split", {static_cast<int64_t>(split_sizes.size())}, split_sizes, is_initializer);
@ -38,9 +53,14 @@ void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAn
for (auto& output : outputs) {
auto& shape = output.first;
auto& data = output.second;
std::ostringstream oss;
oss << "output" << i++;
test.AddOutput<T>(oss.str().c_str(), shape, data);
const auto output_name = MakeString("output", i++);
if constexpr (is_bool_data) {
auto data_array = bool_vector_to_array(data);
test.AddOutput<T>(output_name.c_str(), shape, data_array.get(), data.size());
} else {
test.AddOutput<T>(output_name.c_str(), shape, data);
}
}
std::unordered_set<std::string> excluded_providers;
if (!is_tensorrt_supported) {
@ -49,83 +69,73 @@ void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAn
test.Run(expect_failure ? ExpectResult::kExpectFailure : ExpectResult::kExpectSuccess, err_msg, excluded_providers);
}
TEST(SplitOperatorTest, Axis0EqualSplitFloat) {
constexpr int64_t axis = 0;
std::vector<ShapeAndFloatData> outputs;
// Variable template that is `false` for every template argument.
// ValueFromIdx uses it in the static_assert of its final `else` branch, so the assertion's condition
// depends on the template parameter T.
template <typename>
[[maybe_unused]] constexpr bool dependent_false_v = false;
// input shape and data
ShapeAndFloatData input = {{4, 2}, // shape
{1.f, 2.f,
3.f, 4.f,
5.f, 6.f,
7.f, 8.f}};
outputs.push_back({{2, 2},
{1.f, 2.f,
3.f, 4.f}});
outputs.push_back({{2, 2},
{5.f, 6.f,
7.f, 8.f}});
RunTest<float>(axis, {}, input, outputs, false); //TensorRT parser: Assertion failed: axis != BATCH_DIM
template <typename T>
constexpr T ValueFromIdx(size_t idx) {
if constexpr (std::is_same_v<T, std::string>) {
const char c = gsl::narrow_cast<char>('a' + idx);
return std::string(1, c);
} else if constexpr (std::is_same_v<T, bool>) {
return (idx & 1) == 1;
} else if constexpr (std::is_integral_v<T> || std::is_floating_point_v<T>) {
return gsl::narrow_cast<T>(idx);
} else if constexpr (std::is_same_v<T, MLFloat16> || std::is_same_v<T, BFloat16>) {
return T{static_cast<float>(idx)};
} else {
static_assert(dependent_false_v<T>, "unsupported type");
}
}
template <typename T, typename = typename std::enable_if<std::is_integral<T>::value, T>::type>
static void SplitTestInt() {
template <typename T>
void SplitTestAxis0EqualSplit(bool use_opset_13 = false) {
SCOPED_TRACE(onnxruntime::MakeString("data type: ", utils::ToTensorProtoElementType<T>()));
constexpr int64_t axis = 0;
std::vector<ShapeAndData<T>> outputs;
const auto V = ValueFromIdx<T>;
// input shape and data
ShapeAndData<T> input = {{4, 2}, // shape
{1, 2,
3, 4,
5, 6,
7, 8}};
{V(1), V(2),
V(3), V(4),
V(5), V(6),
V(7), V(8)}};
outputs.push_back({{2, 2},
{1, 2,
3, 4}});
{V(1), V(2),
V(3), V(4)}});
outputs.push_back({{2, 2},
{5, 6,
7, 8}});
{V(5), V(6),
V(7), V(8)}});
RunTest<T>(axis, {}, input, outputs, false); //TensorRT parser: Assertion failed: axis != BATCH_DIM
RunTest<T>(axis, {}, input, outputs,
// TensorRT parser: Assertion failed: axis != BATCH_DIM
false, // is_tensorrt_supported
false, // expect_failure
use_opset_13); // split_as_input
}
// Axis 0 equal split with int8_t data; delegates to the SplitTestInt helper.
TEST(SplitOperatorTest, Axis0EqualSplitInt8) {
SplitTestInt<int8_t>();
}
} // namespace
// Axis 0 equal split with int32_t data; delegates to the SplitTestInt helper.
TEST(SplitOperatorTest, Axis0EqualSplitInt32) {
SplitTestInt<int32_t>();
}
// Axis 0 equal split with int64_t data; delegates to the SplitTestInt helper.
TEST(SplitOperatorTest, Axis0EqualSplitInt64) {
SplitTestInt<int64_t>();
}
TEST(SplitOperatorTest, Axis0EqualSplitString) {
constexpr int64_t axis = 0;
std::vector<ShapeAndStringData> outputs;
// input shape and data
ShapeAndStringData input = {{4, 2}, // shape
{"a", "b",
"c", "d",
"e", "f",
"g", "h"}};
outputs.push_back({{2, 2},
{"a", "b",
"c", "d"}});
outputs.push_back({{2, 2},
{"e", "f",
"g", "h"}});
RunTest<std::string>(axis, {}, input, outputs, false); //TensorRT parser: Assertion failed: axis != BATCH_DIM
// Runs the axis 0 equal split test once for each listed element type.
// Each call uses the helper's default arguments except BFloat16, which passes `true`.
TEST(SplitOperatorTest, Axis0EqualSplit) {
SplitTestAxis0EqualSplit<float>();
SplitTestAxis0EqualSplit<double>();
SplitTestAxis0EqualSplit<MLFloat16>();
SplitTestAxis0EqualSplit<BFloat16>(true); // BFloat16 added in opset 13
SplitTestAxis0EqualSplit<int8_t>();
SplitTestAxis0EqualSplit<int16_t>();
SplitTestAxis0EqualSplit<int32_t>();
SplitTestAxis0EqualSplit<int64_t>();
SplitTestAxis0EqualSplit<uint8_t>();
SplitTestAxis0EqualSplit<uint16_t>();
SplitTestAxis0EqualSplit<uint32_t>();
SplitTestAxis0EqualSplit<uint64_t>();
SplitTestAxis0EqualSplit<bool>();
SplitTestAxis0EqualSplit<std::string>();
}
TEST(SplitOperatorTest, Axis0UnequalSplitFloat) {