mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Add support for other data types to Split CPU kernel. (#13900)
Split only copies data, so we can add support for all data types without too much binary size impact by using implementations that are based on the data type's size rather than on the data type itself. The DispatchStridedCopy() function used here does this.
This commit is contained in:
parent
2cb12caf93
commit
8cfbc4fe91
8 changed files with 130 additions and 179 deletions
|
|
@ -329,9 +329,9 @@ Do not modify directly.*
|
|||
|Softsign|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
|
||||
|SpaceToDepth|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float)|
|
||||
|||[1, 12]|**T** = tensor(double), tensor(float)|
|
||||
|Split|*in* input:**T**<br> *in* split:**T**<br> *out* outputs...:**T**<br><br>or<br><br>*in* input:**T**<br> *in* split:**tensor(int64)**<br> *out* outputs:**T**<br><br>or<br><br>*in* input:**T**<br> *out* outputs:**T**|13+|**T** = tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint8)|
|
||||
|||[11, 12]|**T** = tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint8)|
|
||||
|||[2, 10]|**T** = tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint8)|
|
||||
|Split|*in* input:**T**<br> *in* split:**T**<br> *out* outputs...:**T**<br><br>or<br><br>*in* input:**T**<br> *in* split:**tensor(int64)**<br> *out* outputs:**T**<br><br>or<br><br>*in* input:**T**<br> *out* outputs:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|SplitToSequence|*in* input:**T**<br> *in* split:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string)|
|
||||
|Sqrt|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|
||||
|||[6, 12]|**T** = tensor(double), tensor(float)|
|
||||
|
|
|
|||
|
|
@ -122,19 +122,23 @@ void StridedCopy(concurrency::ThreadPool* thread_pool,
|
|||
!copy_shape.empty(),
|
||||
"src and dst must have same shape and not be rank 0.");
|
||||
|
||||
const std::size_t dims = copy_shape.size();
|
||||
// We will iterate over the output dimensions
|
||||
int64_t num_iterations = 1;
|
||||
for (std::size_t dim = 0; dim < dims; dim++) {
|
||||
num_iterations *= copy_shape[dim];
|
||||
const int64_t total_num_elements_to_copy = copy_shape_in.Size();
|
||||
|
||||
ORT_ENFORCE(total_num_elements_to_copy >= 0, "copy shape must have non-negative size");
|
||||
|
||||
if (total_num_elements_to_copy == 0) {
|
||||
// empty edge case
|
||||
return;
|
||||
}
|
||||
|
||||
if (num_iterations <= 1) {
|
||||
if (total_num_elements_to_copy == 1) {
|
||||
// scalar edge case
|
||||
dst[0] = src[0];
|
||||
return;
|
||||
}
|
||||
|
||||
const std::size_t dims = copy_shape.size();
|
||||
|
||||
// TODOs for when we have strided tensors:
|
||||
// - Reorder dimensions so that we iterate along the smallest strides first
|
||||
|
||||
|
|
@ -151,7 +155,7 @@ void StridedCopy(concurrency::ThreadPool* thread_pool,
|
|||
std::ptrdiff_t contiguous_span_size = static_cast<std::ptrdiff_t>(dims == 2 ? copy_shape[1] : copy_shape[0]);
|
||||
|
||||
concurrency::ThreadPool::TryParallelFor(
|
||||
thread_pool, static_cast<std::ptrdiff_t>(num_iterations),
|
||||
thread_pool, static_cast<std::ptrdiff_t>(total_num_elements_to_copy),
|
||||
{static_cast<float>(sizeof(T)), static_cast<float>(sizeof(T)), 1.0F},
|
||||
[src_stride, dst_stride, dst, src, contiguous_span_size](std::ptrdiff_t first, std::ptrdiff_t last) {
|
||||
// get the current inner and outer index
|
||||
|
|
@ -196,7 +200,7 @@ void StridedCopy(concurrency::ThreadPool* thread_pool,
|
|||
const TensorShapeVector& const_copy_shape = copy_shape;
|
||||
|
||||
concurrency::ThreadPool::TryParallelFor(
|
||||
thread_pool, static_cast<std::ptrdiff_t>(num_iterations),
|
||||
thread_pool, static_cast<std::ptrdiff_t>(total_num_elements_to_copy),
|
||||
{static_cast<float>(sizeof(T)), static_cast<float>(sizeof(T)), 1.0F},
|
||||
[&const_copy_shape, &const_dst_strides, dst, src, &const_src_strides, dims](std::ptrdiff_t first,
|
||||
std::ptrdiff_t last) {
|
||||
|
|
@ -234,6 +238,7 @@ inline bool StridedCopyIfEnabled(concurrency::ThreadPool* thread_pool,
|
|||
const TensorShapeVector& dst_strides,
|
||||
const TensorShape& copy_shape,
|
||||
const Tensor& src,
|
||||
std::ptrdiff_t src_offset,
|
||||
const TensorShapeVector& src_strides) {
|
||||
constexpr bool enabled = utils::HasTypeWithSameSize<EnabledTypes, T>();
|
||||
if constexpr (enabled) {
|
||||
|
|
@ -242,7 +247,7 @@ inline bool StridedCopyIfEnabled(concurrency::ThreadPool* thread_pool,
|
|||
StridedCopy<T>(thread_pool,
|
||||
reinterpret_cast<T*>(dst.MutableDataRaw()) + dst_offset,
|
||||
dst_strides, copy_shape,
|
||||
reinterpret_cast<const T*>(src.DataRaw()),
|
||||
reinterpret_cast<const T*>(src.DataRaw()) + src_offset,
|
||||
src_strides);
|
||||
}
|
||||
|
||||
|
|
@ -259,40 +264,41 @@ Status DispatchStridedCopy(concurrency::ThreadPool* thread_pool,
|
|||
const TensorShapeVector& dst_strides,
|
||||
const TensorShape& copy_shape,
|
||||
const Tensor& src,
|
||||
std::ptrdiff_t src_offset,
|
||||
const TensorShapeVector& src_strides) {
|
||||
ORT_ENFORCE(dst.DataType() == src.DataType(), "src and dst types must match");
|
||||
|
||||
bool supported = false;
|
||||
if (src.IsDataTypeString()) {
|
||||
if (utils::HasType<EnabledDataTypes, std::string>()) {
|
||||
if constexpr (utils::HasType<EnabledDataTypes, std::string>()) {
|
||||
supported = true;
|
||||
StridedCopy(thread_pool, dst.MutableData<std::string>() + dst_offset, dst_strides, copy_shape,
|
||||
src.Data<std::string>(), src_strides);
|
||||
src.Data<std::string>() + src_offset, src_strides);
|
||||
}
|
||||
} else {
|
||||
const auto element_size = src.DataType()->Size();
|
||||
switch (element_size) {
|
||||
case sizeof(uint32_t):
|
||||
supported = StridedCopyIfEnabled<EnabledDataTypes, uint32_t>(thread_pool, dst, dst_offset, dst_strides,
|
||||
copy_shape, src, src_strides);
|
||||
copy_shape, src, src_offset, src_strides);
|
||||
break;
|
||||
case sizeof(uint64_t):
|
||||
supported = StridedCopyIfEnabled<EnabledDataTypes, uint64_t>(thread_pool, dst, dst_offset, dst_strides,
|
||||
copy_shape, src, src_strides);
|
||||
copy_shape, src, src_offset, src_strides);
|
||||
break;
|
||||
case sizeof(uint16_t):
|
||||
supported = StridedCopyIfEnabled<EnabledDataTypes, uint16_t>(thread_pool, dst, dst_offset, dst_strides,
|
||||
copy_shape, src, src_strides);
|
||||
copy_shape, src, src_offset, src_strides);
|
||||
break;
|
||||
case sizeof(uint8_t):
|
||||
static_assert(sizeof(bool) == sizeof(uint8_t), "Need to enable separate case for 'bool' on this platform.");
|
||||
supported = StridedCopyIfEnabled<EnabledDataTypes, uint8_t>(thread_pool, dst, dst_offset, dst_strides,
|
||||
copy_shape, src, src_strides);
|
||||
copy_shape, src, src_offset, src_strides);
|
||||
break;
|
||||
// It's possible that bool is not 1 byte. static_assert above checks if we need to enable this on a platform.
|
||||
//case sizeof(bool):
|
||||
// case sizeof(bool):
|
||||
// supported = StridedCopyIfEnabled<EnabledDataTypes, bool>(thread_pool, dst, dst_offset, dst_strides,
|
||||
// copy_shape, src, src_strides);
|
||||
// copy_shape, src, src_offset, src_strides);
|
||||
// break;
|
||||
default:
|
||||
// leave 'supported' as false
|
||||
|
|
|
|||
|
|
@ -56,8 +56,10 @@ common::Status CPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int /
|
|||
auto src_stride_vec = src.Strides();
|
||||
onnxruntime::TensorShapeVector dst_stride{dst_stride_vec.begin(), dst_stride_vec.end()};
|
||||
onnxruntime::TensorShapeVector src_stride{src_stride_vec.begin(), src_stride_vec.end()};
|
||||
return DispatchStridedCopy<element_type_lists::All>(nullptr, dst, dst.ByteOffset(), dst_stride, src.Shape(), src,
|
||||
src_stride);
|
||||
return DispatchStridedCopy<element_type_lists::All>(nullptr,
|
||||
dst, 0, dst_stride,
|
||||
src.Shape(),
|
||||
src, 0, src_stride);
|
||||
} else {
|
||||
#endif
|
||||
// Copying only happens between two same size tensors.
|
||||
|
|
|
|||
|
|
@ -266,10 +266,11 @@ Status ConcatBase::ComputeImpl(Prepare& p, OpKernelContext* ctx) const {
|
|||
// parallel copy the data across
|
||||
auto status = DispatchStridedCopy<EnabledDataTypes>(ctx->GetOperatorThreadPool(),
|
||||
*p.output_tensor,
|
||||
onnxruntime::narrow<size_t>(initial_output_offset),
|
||||
onnxruntime::narrow<ptrdiff_t>(initial_output_offset),
|
||||
output_strides_for_copy,
|
||||
prep.tensor->Shape(),
|
||||
*prep.tensor,
|
||||
0, // src_offset
|
||||
StridesForTensor(*prep.tensor));
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <core/common/safeint.h>
|
||||
#include "core/providers/cpu/tensor/split.h"
|
||||
|
||||
#include "core/common/gsl.h"
|
||||
|
||||
#include "core/common/narrow.h"
|
||||
#include "core/common/gsl.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/framework/copy.h"
|
||||
#include "core/framework/element_type_lists.h"
|
||||
#include "core/framework/op_kernel_type_control_utils.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/op_kernel_type_control.h"
|
||||
|
|
@ -16,9 +17,9 @@
|
|||
namespace onnxruntime {
|
||||
|
||||
namespace op_kernel_type_control {
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, Split, Input, 0,
|
||||
float, int8_t, int32_t, int64_t, uint8_t, std::string);
|
||||
element_type_lists::All);
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, Split, Input, 0,
|
||||
int32_t, int64_t);
|
||||
|
|
@ -93,63 +94,25 @@ Status SplitBase::PrepareForCompute(const TensorShape& input_shape, int num_outp
|
|||
|
||||
Status Split::Compute(OpKernelContext* context) const {
|
||||
const Tensor& input = *context->Input<Tensor>(0);
|
||||
|
||||
Status status;
|
||||
|
||||
// Note: The non-string implementations can probably be based on data type size.
|
||||
if (input.IsDataType<float>())
|
||||
status = ComputeImpl<float>(*context, input);
|
||||
else if (input.IsDataType<int32_t>())
|
||||
status = ComputeImpl<int32_t>(*context, input);
|
||||
else if (input.IsDataType<int64_t>())
|
||||
status = ComputeImpl<int64_t>(*context, input);
|
||||
else if (input.IsDataType<uint8_t>())
|
||||
status = ComputeImpl<uint8_t>(*context, input);
|
||||
else if (input.IsDataType<int8_t>())
|
||||
status = ComputeImpl<int8_t>(*context, input);
|
||||
else if (input.IsDataTypeString())
|
||||
status = ComputeImpl<std::string>(*context, input);
|
||||
else
|
||||
ORT_THROW("Split operator does not support ", input.DataType(), " yet");
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void copy_data(const T* src, T* dst, size_t count) {
|
||||
memcpy(dst, src, count * sizeof(T));
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void copy_data<std::string>(const std::string* src, std::string* dst, size_t count) {
|
||||
const std::string* end = src + count;
|
||||
std::copy(src, end, dst);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Split::ComputeImpl(OpKernelContext& context, const Tensor& input) const {
|
||||
if (!utils::HasType<EnabledSplitDataTypes, T>()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Data type is not supported in this build.");
|
||||
}
|
||||
|
||||
auto& input_shape = input.Shape();
|
||||
auto num_outputs = context.OutputCount();
|
||||
auto num_outputs = context->OutputCount();
|
||||
int64_t axis = axis_;
|
||||
int before_dims = 0;
|
||||
int after_dims_including_split_axis = 0;
|
||||
int after_dims_excluding_split = 0;
|
||||
std::vector<int64_t> split_sizes;
|
||||
|
||||
const Tensor* split_tensor = context.Input<Tensor>(1);
|
||||
const Tensor* split_tensor = context->Input<Tensor>(1);
|
||||
if (split_tensor != nullptr) {
|
||||
//override the attribute value with the input value for split
|
||||
ORT_ENFORCE(split_tensor->Shape().NumDimensions() == 1, "An split tensor must be a vector tensor.");
|
||||
// override the attribute value with the input value for split
|
||||
ORT_ENFORCE(split_tensor->Shape().NumDimensions() == 1, "The split tensor must be a vector tensor.");
|
||||
auto nDims = static_cast<size_t>(split_tensor->Shape()[0]);
|
||||
const auto* data = split_tensor->Data<int64_t>();
|
||||
split_sizes.assign(data, data + nDims);
|
||||
} else {
|
||||
split_sizes.assign(split_sizes_.begin(), split_sizes_.end());
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(input_shape,
|
||||
num_outputs,
|
||||
axis,
|
||||
|
|
@ -158,32 +121,27 @@ Status Split::ComputeImpl(OpKernelContext& context, const Tensor& input) const {
|
|||
after_dims_excluding_split,
|
||||
split_sizes));
|
||||
|
||||
const auto input_strides = StridesForTensor(input);
|
||||
|
||||
// copy dimensions so we can update the selected axis in place
|
||||
auto output_dimensions = input_shape.AsShapeVector();
|
||||
|
||||
int64_t input_offset = 0;
|
||||
const T* input_data = input.Data<T>();
|
||||
SafeInt<ptrdiff_t> input_offset = 0;
|
||||
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
// update size of dimension for axis we're splitting on
|
||||
auto split_size = narrow<int>(split_sizes[i]);
|
||||
output_dimensions[onnxruntime::narrow<size_t>(axis)] = split_size;
|
||||
output_dimensions[narrow<size_t>(axis)] = split_size;
|
||||
|
||||
Tensor* output = context.Output(i, TensorShape{output_dimensions});
|
||||
T* output_data = output->MutableData<T>();
|
||||
Tensor* output = context->Output(i, TensorShape{output_dimensions});
|
||||
const auto output_strides = StridesForTensor(*output);
|
||||
|
||||
::onnxruntime::math::CopyMatrix<T>(
|
||||
before_dims, // M
|
||||
split_size * after_dims_excluding_split, // N
|
||||
static_cast<const T*>(input_data + input_offset), // A
|
||||
after_dims_including_split_axis, // lda
|
||||
static_cast<T*>(output_data), // B
|
||||
split_size * after_dims_excluding_split, // ldb
|
||||
[](const T* src, T* dst, size_t count) {
|
||||
copy_data<T>(src, dst, count);
|
||||
});
|
||||
ORT_RETURN_IF_ERROR(DispatchStridedCopy<EnabledSplitDataTypes>(context->GetOperatorThreadPool(),
|
||||
*output, /* dst_offset */ 0, output_strides,
|
||||
output->Shape(),
|
||||
input, input_offset, input_strides));
|
||||
|
||||
input_offset += static_cast<int64_t>(split_size) * after_dims_excluding_split; // offset by the N data we used in this iteration
|
||||
input_offset += SafeInt<ptrdiff_t>(split_size) * after_dims_excluding_split; // offset by the data we used in this iteration
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -46,10 +46,6 @@ class Split final : public OpKernel, public SplitBase {
|
|||
Split(const OpKernelInfo& info) : OpKernel(info), SplitBase(info) {}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
Status ComputeImpl(OpKernelContext& context, const Tensor& input) const;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -144,17 +144,6 @@ TEST_F(CopyTest, CoalesceTensorsTest) {
|
|||
ASSERT_THAT(strides_b, testing::ElementsAre(1));
|
||||
ASSERT_THAT(shape, testing::ElementsAre(15));
|
||||
}
|
||||
{
|
||||
TensorShapeVector strides_a{3, 3, 3, 1};
|
||||
TensorShapeVector strides_b{3, 3, 3, 1};
|
||||
TensorShapeVector shape{1, 5, 1, 3};
|
||||
|
||||
CoalesceDimensions({strides_a, strides_b}, shape);
|
||||
|
||||
ASSERT_THAT(strides_a, testing::ElementsAre(1));
|
||||
ASSERT_THAT(strides_b, testing::ElementsAre(1));
|
||||
ASSERT_THAT(shape, testing::ElementsAre(15));
|
||||
}
|
||||
{
|
||||
TensorShapeVector strides_a{320, 1};
|
||||
TensorShapeVector strides_b{320, 1};
|
||||
|
|
@ -188,17 +177,6 @@ TEST_F(CopyTest, CoalesceTensorsTest) {
|
|||
ASSERT_THAT(strides_b, testing::ElementsAre(6, 1));
|
||||
ASSERT_THAT(shape, testing::ElementsAre(5, 3));
|
||||
}
|
||||
{
|
||||
TensorShapeVector strides_a{3, 1};
|
||||
TensorShapeVector strides_b{6, 1};
|
||||
TensorShapeVector shape{5, 3};
|
||||
|
||||
CoalesceDimensions({strides_a, strides_b}, shape);
|
||||
|
||||
ASSERT_THAT(strides_a, testing::ElementsAre(3, 1));
|
||||
ASSERT_THAT(strides_b, testing::ElementsAre(6, 1));
|
||||
ASSERT_THAT(shape, testing::ElementsAre(5, 3));
|
||||
}
|
||||
{
|
||||
TensorShapeVector strides_a{4, 1};
|
||||
TensorShapeVector strides_b{1, 1};
|
||||
|
|
|
|||
|
|
@ -2,11 +2,14 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "core/framework/to_tensor_proto_element_type.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
namespace {
|
||||
|
||||
template <class T>
|
||||
using ShapeAndData = std::pair<const std::vector<int64_t>, const std::vector<T>>;
|
||||
|
||||
|
|
@ -15,15 +18,27 @@ using ShapeAndStringData = ShapeAndData<std::string>;
|
|||
using ExpectResult = OpTester::ExpectResult;
|
||||
|
||||
template <typename T>
|
||||
void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAndData<T>& input,
|
||||
void RunTest(int64_t axis, const std::vector<int64_t>& split_sizes, const ShapeAndData<T>& input,
|
||||
const std::vector<ShapeAndData<T>>& outputs, bool is_tensorrt_supported = true,
|
||||
bool expect_failure = false, bool split_as_input = false,
|
||||
bool is_initializer = true, const std::string& err_msg = {}, bool skip_split_if_empty = true) {
|
||||
int opset_version = split_as_input ? 13 : 7;
|
||||
OpTester test("Split", opset_version, onnxruntime::kOnnxDomain);
|
||||
|
||||
constexpr bool is_bool_data = std::is_same_v<T, bool>;
|
||||
[[maybe_unused]] auto bool_vector_to_array = [](const std::vector<bool>& v) -> std::unique_ptr<bool[]> {
|
||||
auto a = std::make_unique<bool[]>(v.size());
|
||||
std::copy(v.begin(), v.end(), a.get());
|
||||
return a;
|
||||
};
|
||||
|
||||
test.AddAttribute("axis", axis);
|
||||
test.AddInput<T>("input", input.first, input.second);
|
||||
if constexpr (is_bool_data) {
|
||||
auto input_array = bool_vector_to_array(input.second);
|
||||
test.AddInput<T>("input", input.first, input_array.get(), input.second.size());
|
||||
} else {
|
||||
test.AddInput<T>("input", input.first, input.second);
|
||||
}
|
||||
if (!split_sizes.empty()) {
|
||||
if (split_as_input) {
|
||||
test.AddInput<int64_t>("split", {static_cast<int64_t>(split_sizes.size())}, split_sizes, is_initializer);
|
||||
|
|
@ -38,9 +53,14 @@ void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAn
|
|||
for (auto& output : outputs) {
|
||||
auto& shape = output.first;
|
||||
auto& data = output.second;
|
||||
std::ostringstream oss;
|
||||
oss << "output" << i++;
|
||||
test.AddOutput<T>(oss.str().c_str(), shape, data);
|
||||
const auto output_name = MakeString("output", i++);
|
||||
|
||||
if constexpr (is_bool_data) {
|
||||
auto data_array = bool_vector_to_array(data);
|
||||
test.AddOutput<T>(output_name.c_str(), shape, data_array.get(), data.size());
|
||||
} else {
|
||||
test.AddOutput<T>(output_name.c_str(), shape, data);
|
||||
}
|
||||
}
|
||||
std::unordered_set<std::string> excluded_providers;
|
||||
if (!is_tensorrt_supported) {
|
||||
|
|
@ -49,83 +69,73 @@ void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAn
|
|||
test.Run(expect_failure ? ExpectResult::kExpectFailure : ExpectResult::kExpectSuccess, err_msg, excluded_providers);
|
||||
}
|
||||
|
||||
TEST(SplitOperatorTest, Axis0EqualSplitFloat) {
|
||||
constexpr int64_t axis = 0;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
template <typename>
|
||||
[[maybe_unused]] constexpr bool dependent_false_v = false;
|
||||
|
||||
// input shape and data
|
||||
ShapeAndFloatData input = {{4, 2}, // shape
|
||||
{1.f, 2.f,
|
||||
3.f, 4.f,
|
||||
5.f, 6.f,
|
||||
7.f, 8.f}};
|
||||
|
||||
outputs.push_back({{2, 2},
|
||||
{1.f, 2.f,
|
||||
3.f, 4.f}});
|
||||
|
||||
outputs.push_back({{2, 2},
|
||||
{5.f, 6.f,
|
||||
7.f, 8.f}});
|
||||
|
||||
RunTest<float>(axis, {}, input, outputs, false); //TensorRT parser: Assertion failed: axis != BATCH_DIM
|
||||
template <typename T>
|
||||
constexpr T ValueFromIdx(size_t idx) {
|
||||
if constexpr (std::is_same_v<T, std::string>) {
|
||||
const char c = gsl::narrow_cast<char>('a' + idx);
|
||||
return std::string(1, c);
|
||||
} else if constexpr (std::is_same_v<T, bool>) {
|
||||
return (idx & 1) == 1;
|
||||
} else if constexpr (std::is_integral_v<T> || std::is_floating_point_v<T>) {
|
||||
return gsl::narrow_cast<T>(idx);
|
||||
} else if constexpr (std::is_same_v<T, MLFloat16> || std::is_same_v<T, BFloat16>) {
|
||||
return T{static_cast<float>(idx)};
|
||||
} else {
|
||||
static_assert(dependent_false_v<T>, "unsupported type");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename = typename std::enable_if<std::is_integral<T>::value, T>::type>
|
||||
static void SplitTestInt() {
|
||||
template <typename T>
|
||||
void SplitTestAxis0EqualSplit(bool use_opset_13 = false) {
|
||||
SCOPED_TRACE(onnxruntime::MakeString("data type: ", utils::ToTensorProtoElementType<T>()));
|
||||
|
||||
constexpr int64_t axis = 0;
|
||||
std::vector<ShapeAndData<T>> outputs;
|
||||
|
||||
const auto V = ValueFromIdx<T>;
|
||||
|
||||
// input shape and data
|
||||
ShapeAndData<T> input = {{4, 2}, // shape
|
||||
{1, 2,
|
||||
3, 4,
|
||||
5, 6,
|
||||
7, 8}};
|
||||
{V(1), V(2),
|
||||
V(3), V(4),
|
||||
V(5), V(6),
|
||||
V(7), V(8)}};
|
||||
|
||||
outputs.push_back({{2, 2},
|
||||
{1, 2,
|
||||
3, 4}});
|
||||
{V(1), V(2),
|
||||
V(3), V(4)}});
|
||||
|
||||
outputs.push_back({{2, 2},
|
||||
{5, 6,
|
||||
7, 8}});
|
||||
{V(5), V(6),
|
||||
V(7), V(8)}});
|
||||
|
||||
RunTest<T>(axis, {}, input, outputs, false); //TensorRT parser: Assertion failed: axis != BATCH_DIM
|
||||
RunTest<T>(axis, {}, input, outputs,
|
||||
// TensorRT parser: Assertion failed: axis != BATCH_DIM
|
||||
false, // is_tensorrt_supported
|
||||
false, // expect_failure
|
||||
use_opset_13); // split_as_input
|
||||
}
|
||||
|
||||
TEST(SplitOperatorTest, Axis0EqualSplitInt8) {
|
||||
SplitTestInt<int8_t>();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(SplitOperatorTest, Axis0EqualSplitInt32) {
|
||||
SplitTestInt<int32_t>();
|
||||
}
|
||||
|
||||
TEST(SplitOperatorTest, Axis0EqualSplitInt64) {
|
||||
SplitTestInt<int64_t>();
|
||||
}
|
||||
|
||||
TEST(SplitOperatorTest, Axis0EqualSplitString) {
|
||||
constexpr int64_t axis = 0;
|
||||
std::vector<ShapeAndStringData> outputs;
|
||||
|
||||
// input shape and data
|
||||
ShapeAndStringData input = {{4, 2}, // shape
|
||||
{"a", "b",
|
||||
"c", "d",
|
||||
"e", "f",
|
||||
"g", "h"}};
|
||||
|
||||
outputs.push_back({{2, 2},
|
||||
{"a", "b",
|
||||
"c", "d"}});
|
||||
|
||||
outputs.push_back({{2, 2},
|
||||
{"e", "f",
|
||||
"g", "h"}});
|
||||
|
||||
RunTest<std::string>(axis, {}, input, outputs, false); //TensorRT parser: Assertion failed: axis != BATCH_DIM
|
||||
TEST(SplitOperatorTest, Axis0EqualSplit) {
|
||||
SplitTestAxis0EqualSplit<float>();
|
||||
SplitTestAxis0EqualSplit<double>();
|
||||
SplitTestAxis0EqualSplit<MLFloat16>();
|
||||
SplitTestAxis0EqualSplit<BFloat16>(true); // BFloat16 added in opset 13
|
||||
SplitTestAxis0EqualSplit<int8_t>();
|
||||
SplitTestAxis0EqualSplit<int16_t>();
|
||||
SplitTestAxis0EqualSplit<int32_t>();
|
||||
SplitTestAxis0EqualSplit<int64_t>();
|
||||
SplitTestAxis0EqualSplit<uint8_t>();
|
||||
SplitTestAxis0EqualSplit<uint16_t>();
|
||||
SplitTestAxis0EqualSplit<uint32_t>();
|
||||
SplitTestAxis0EqualSplit<uint64_t>();
|
||||
SplitTestAxis0EqualSplit<bool>();
|
||||
SplitTestAxis0EqualSplit<std::string>();
|
||||
}
|
||||
|
||||
TEST(SplitOperatorTest, Axis0UnequalSplitFloat) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue