mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
Add type reduction support to Slice and Transpose (#6547)
* Add type reduction support to Slice and Transpose
This commit is contained in:
parent
89627a8178
commit
c49d1dbc4b
10 changed files with 286 additions and 165 deletions
|
|
@ -116,7 +116,7 @@ constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<BFloat16
|
|||
function<int8_t>(__VA_ARGS__); \
|
||||
break; \
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
|
||||
function<uint8_t>(__VA_ARGS__); \
|
||||
function<uint8_t>(__VA_ARGS__); \
|
||||
break; \
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
|
||||
function<int16_t>(__VA_ARGS__); \
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
#include "core/providers/cpu/tensor/slice.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/op_kernel_type_control.h"
|
||||
#include "core/providers/op_kernel_type_control_utils.h"
|
||||
#include <unordered_map>
|
||||
#include <limits>
|
||||
|
||||
|
|
@ -11,40 +13,28 @@ using namespace ::onnxruntime::common;
|
|||
using namespace std;
|
||||
|
||||
namespace onnxruntime {
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Slice,
|
||||
1, 9,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
|
||||
Slice1);
|
||||
namespace op_kernel_type_control {
|
||||
// we're using one set of types for all opsets
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, Slice, Input, 0,
|
||||
ORT_OP_KERNEL_TYPE_CTRL_ALL_TENSOR_DATA_TYPES);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Slice,
|
||||
10, 10,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
Slice10);
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, Slice, Input, 1, int32_t, int64_t);
|
||||
} // namespace op_kernel_type_control
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Slice,
|
||||
11,
|
||||
12,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
Slice10);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Slice,
|
||||
13,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
Slice10);
|
||||
namespace {
|
||||
using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
|
||||
Slice, Input, 0);
|
||||
using EnabledIndicesTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
|
||||
Slice, Input, 1);
|
||||
|
||||
const std::vector<MLDataType> dataTypeConstraints =
|
||||
BuildKernelDefConstraintsFunctorFromTypeList<EnabledDataTypes>{}();
|
||||
|
||||
const std::vector<MLDataType> indicesTypeConstraints =
|
||||
BuildKernelDefConstraintsFunctorFromTypeList<EnabledIndicesTypes>{}();
|
||||
|
||||
// std::clamp doesn't exist until C++17 so create a local version
|
||||
template <typename T>
|
||||
const T& clamp(const T& v, const T& lo, const T& hi) {
|
||||
|
|
@ -54,6 +44,37 @@ const T& clamp(const T& v, const T& lo, const T& hi) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Slice,
|
||||
1, 9,
|
||||
KernelDefBuilder().TypeConstraint("T", dataTypeConstraints),
|
||||
Slice1);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Slice,
|
||||
10, 10,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", dataTypeConstraints)
|
||||
.TypeConstraint("Tind", indicesTypeConstraints),
|
||||
Slice10);
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Slice,
|
||||
11,
|
||||
12,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", dataTypeConstraints)
|
||||
.TypeConstraint("Tind", indicesTypeConstraints),
|
||||
Slice10);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Slice,
|
||||
13,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", dataTypeConstraints)
|
||||
.TypeConstraint("Tind", indicesTypeConstraints),
|
||||
Slice10);
|
||||
|
||||
// Check if it's possible to combine innermost dimensions so we copy larger blocks.
|
||||
// Sets flattened_output_dims to nullptr if it is not.
|
||||
// Updates starts and steps to match flattened_output_dims if it is.
|
||||
|
|
@ -218,19 +239,21 @@ Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
|
|||
}
|
||||
|
||||
// Slice V10 & DynamicSlice
|
||||
void SliceBase::FillVectorsFromInput(const Tensor& start_tensor,
|
||||
const Tensor& ends_tensor,
|
||||
const Tensor* axes_tensor,
|
||||
const Tensor* steps_tensor,
|
||||
std::vector<int64_t>& input_starts,
|
||||
std::vector<int64_t>& input_ends,
|
||||
std::vector<int64_t>& input_axes,
|
||||
std::vector<int64_t>& input_steps) {
|
||||
ORT_ENFORCE(start_tensor.Shape().NumDimensions() == 1, "Starts must be a 1-D array");
|
||||
ORT_ENFORCE(ends_tensor.Shape().NumDimensions() == 1, "Ends must be a 1-D array");
|
||||
ORT_ENFORCE(start_tensor.Shape() == ends_tensor.Shape(), "Starts and ends shape mismatch");
|
||||
ORT_ENFORCE(nullptr == axes_tensor || start_tensor.Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch");
|
||||
ORT_ENFORCE(nullptr == steps_tensor || start_tensor.Shape() == steps_tensor->Shape(), "Starts and steps shape mismatch");
|
||||
Status SliceBase::FillVectorsFromInput(const Tensor& start_tensor,
|
||||
const Tensor& ends_tensor,
|
||||
const Tensor* axes_tensor,
|
||||
const Tensor* steps_tensor,
|
||||
std::vector<int64_t>& input_starts,
|
||||
std::vector<int64_t>& input_ends,
|
||||
std::vector<int64_t>& input_axes,
|
||||
std::vector<int64_t>& input_steps) {
|
||||
ORT_RETURN_IF_NOT(start_tensor.Shape().NumDimensions() == 1, "Starts must be a 1-D array");
|
||||
ORT_RETURN_IF_NOT(ends_tensor.Shape().NumDimensions() == 1, "Ends must be a 1-D array");
|
||||
ORT_RETURN_IF_NOT(start_tensor.Shape() == ends_tensor.Shape(), "Starts and ends shape mismatch");
|
||||
ORT_RETURN_IF_NOT(nullptr == axes_tensor || start_tensor.Shape() == axes_tensor->Shape(),
|
||||
"Starts and axes shape mismatch");
|
||||
ORT_RETURN_IF_NOT(nullptr == steps_tensor || start_tensor.Shape() == steps_tensor->Shape(),
|
||||
"Starts and steps shape mismatch");
|
||||
|
||||
const auto& size = start_tensor.Shape().Size();
|
||||
input_starts.resize(size);
|
||||
|
|
@ -241,7 +264,11 @@ void SliceBase::FillVectorsFromInput(const Tensor& start_tensor,
|
|||
if (nullptr != steps_tensor)
|
||||
input_steps.resize(size);
|
||||
|
||||
if (start_tensor.IsDataType<int32_t>()) {
|
||||
// check for type reduction of supported indices types
|
||||
constexpr bool int32_enabled = utils::HasType<EnabledIndicesTypes, int32_t>();
|
||||
constexpr bool int64_enabled = utils::HasType<EnabledIndicesTypes, int64_t>();
|
||||
|
||||
if (int32_enabled && start_tensor.IsDataType<int32_t>()) {
|
||||
std::copy(start_tensor.Data<int32_t>(), start_tensor.Data<int32_t>() + size, input_starts.begin());
|
||||
std::copy(ends_tensor.Data<int32_t>(), ends_tensor.Data<int32_t>() + size, input_ends.begin());
|
||||
if (nullptr != axes_tensor)
|
||||
|
|
@ -251,7 +278,7 @@ void SliceBase::FillVectorsFromInput(const Tensor& start_tensor,
|
|||
std::copy(steps_tensor->Data<int32_t>(), steps_tensor->Data<int32_t>() + size, input_steps.begin());
|
||||
}
|
||||
|
||||
else if (start_tensor.IsDataType<int64_t>()) {
|
||||
else if (int64_enabled && start_tensor.IsDataType<int64_t>()) {
|
||||
std::copy(start_tensor.Data<int64_t>(), start_tensor.Data<int64_t>() + size, input_starts.begin());
|
||||
std::copy(ends_tensor.Data<int64_t>(), ends_tensor.Data<int64_t>() + size, input_ends.begin());
|
||||
if (nullptr != axes_tensor)
|
||||
|
|
@ -261,10 +288,13 @@ void SliceBase::FillVectorsFromInput(const Tensor& start_tensor,
|
|||
std::copy(steps_tensor->Data<int64_t>(), steps_tensor->Data<int64_t>() + size, input_steps.begin());
|
||||
}
|
||||
|
||||
// should not reach this as no kernel is registered for this condition to be triggered - just an additional safety check
|
||||
else {
|
||||
ORT_THROW("Data type for starts and ends inputs' need to be int32_t or int64_t, but instead got ", start_tensor.DataType());
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
|
||||
"Data type for starts and ends inputs' is not supported in this build. Got ",
|
||||
start_tensor.DataType());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -304,22 +334,40 @@ static Status SliceImpl(OpKernelContext* ctx,
|
|||
flattened_input_dims.back() = compute_metadata.p_flattened_output_dims_->back();
|
||||
TensorShape input_shape(std::move(flattened_input_dims));
|
||||
|
||||
auto input_iterator = SliceIterator<T>(input_tensor, input_shape, compute_metadata.starts_, *compute_metadata.p_flattened_output_dims_, compute_metadata.steps_);
|
||||
auto input_iterator = SliceIterator<T>(input_tensor, input_shape, compute_metadata.starts_,
|
||||
*compute_metadata.p_flattened_output_dims_, compute_metadata.steps_);
|
||||
create_output(input_iterator);
|
||||
} else {
|
||||
auto input_iterator = SliceIterator<T>(input_tensor, compute_metadata.starts_, compute_metadata.output_dims_, compute_metadata.steps_);
|
||||
auto input_iterator = SliceIterator<T>(input_tensor, compute_metadata.starts_, compute_metadata.output_dims_,
|
||||
compute_metadata.steps_);
|
||||
create_output(input_iterator);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename EnabledTypes, typename T>
|
||||
static inline bool CallSliceImplIfEnabled(OpKernelContext* ctx,
|
||||
const Tensor& input_tensor,
|
||||
SliceOp::PrepareForComputeMetadata& compute_metadata,
|
||||
Status& status) {
|
||||
constexpr bool enabled = utils::HasTypeWithSameSize<EnabledTypes, T>();
|
||||
if (enabled) {
|
||||
status = SliceImpl<T>(ctx, input_tensor, compute_metadata);
|
||||
}
|
||||
|
||||
return enabled;
|
||||
}
|
||||
|
||||
Status SliceBase::Compute(OpKernelContext* ctx) const {
|
||||
const auto* input_tensor_ptr = ctx->Input<Tensor>(0);
|
||||
ORT_ENFORCE(input_tensor_ptr != nullptr, "Missing input tensor to be processed");
|
||||
const auto& input_tensor = *input_tensor_ptr;
|
||||
const auto& input_dimensions = input_tensor.Shape().GetDims();
|
||||
if (input_dimensions.empty()) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars");
|
||||
|
||||
if (input_dimensions.empty()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars");
|
||||
}
|
||||
|
||||
SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions);
|
||||
|
||||
// Slice V10 & DynamicSlice
|
||||
|
|
@ -328,8 +376,9 @@ Status SliceBase::Compute(OpKernelContext* ctx) const {
|
|||
std::vector<int64_t> input_ends;
|
||||
std::vector<int64_t> input_axes;
|
||||
std::vector<int64_t> input_steps;
|
||||
FillVectorsFromInput(*ctx->Input<Tensor>(1), *ctx->Input<Tensor>(2), ctx->Input<Tensor>(3),
|
||||
ctx->Input<Tensor>(4), input_starts, input_ends, input_axes, input_steps);
|
||||
ORT_RETURN_IF_ERROR(FillVectorsFromInput(*ctx->Input<Tensor>(1), *ctx->Input<Tensor>(2),
|
||||
ctx->Input<Tensor>(3), ctx->Input<Tensor>(4),
|
||||
input_starts, input_ends, input_axes, input_steps));
|
||||
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata));
|
||||
}
|
||||
|
|
@ -340,28 +389,38 @@ Status SliceBase::Compute(OpKernelContext* ctx) const {
|
|||
|
||||
Status status = Status::OK();
|
||||
|
||||
bool supported = false;
|
||||
if (input_tensor.IsDataTypeString()) {
|
||||
status = SliceImpl<std::string>(ctx, input_tensor, compute_metadata);
|
||||
if (utils::HasType<EnabledDataTypes, std::string>()) {
|
||||
supported = true;
|
||||
status = SliceImpl<std::string>(ctx, input_tensor, compute_metadata);
|
||||
}
|
||||
} else {
|
||||
const auto element_size = input_tensor.DataType()->Size();
|
||||
|
||||
// call SliceImpl
|
||||
switch (element_size) {
|
||||
case sizeof(uint32_t):
|
||||
status = SliceImpl<uint32_t>(ctx, input_tensor, compute_metadata);
|
||||
supported = CallSliceImplIfEnabled<EnabledDataTypes, uint32_t>(ctx, input_tensor, compute_metadata, status);
|
||||
break;
|
||||
case sizeof(uint64_t):
|
||||
status = SliceImpl<uint64_t>(ctx, input_tensor, compute_metadata);
|
||||
supported = CallSliceImplIfEnabled<EnabledDataTypes, uint64_t>(ctx, input_tensor, compute_metadata, status);
|
||||
break;
|
||||
case sizeof(uint16_t):
|
||||
status = SliceImpl<uint16_t>(ctx, input_tensor, compute_metadata);
|
||||
supported = CallSliceImplIfEnabled<EnabledDataTypes, uint16_t>(ctx, input_tensor, compute_metadata, status);
|
||||
break;
|
||||
case sizeof(uint8_t):
|
||||
status = SliceImpl<uint8_t>(ctx, input_tensor, compute_metadata);
|
||||
supported = CallSliceImplIfEnabled<EnabledDataTypes, uint8_t>(ctx, input_tensor, compute_metadata, status);
|
||||
break;
|
||||
default:
|
||||
ORT_THROW("Unsupported input data type of ", input_tensor.DataType());
|
||||
// leave 'supported' as false
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!supported) {
|
||||
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input data type of ", input_tensor.DataType());
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -44,14 +44,14 @@ class SliceBase {
|
|||
SliceOp::PrepareForComputeMetadata& compute_metadata);
|
||||
|
||||
// Slice V10 & DynamicSlice
|
||||
static void FillVectorsFromInput(const Tensor& start_tensor,
|
||||
const Tensor& ends_tensor,
|
||||
const Tensor* axes_tensor,
|
||||
const Tensor* steps_tensor,
|
||||
std::vector<int64_t>& input_starts,
|
||||
std::vector<int64_t>& input_ends,
|
||||
std::vector<int64_t>& input_axes,
|
||||
std::vector<int64_t>& input_steps);
|
||||
static Status FillVectorsFromInput(const Tensor& start_tensor,
|
||||
const Tensor& ends_tensor,
|
||||
const Tensor* axes_tensor,
|
||||
const Tensor* steps_tensor,
|
||||
std::vector<int64_t>& input_starts,
|
||||
std::vector<int64_t>& input_ends,
|
||||
std::vector<int64_t>& input_axes,
|
||||
std::vector<int64_t>& input_steps);
|
||||
|
||||
protected:
|
||||
SliceBase(const OpKernelInfo& info, bool dynamic = false)
|
||||
|
|
|
|||
|
|
@ -2,12 +2,30 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cpu/tensor/transpose.h"
|
||||
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/providers/op_kernel_type_control.h"
|
||||
#include "core/providers/op_kernel_type_control_utils.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace op_kernel_type_control {
|
||||
// we're using one set of types for all opsets
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, Transpose, Input, 0,
|
||||
ORT_OP_KERNEL_TYPE_CTRL_ALL_TENSOR_DATA_TYPES);
|
||||
} // namespace op_kernel_type_control
|
||||
|
||||
namespace {
|
||||
// reduce the supported types with any global or op specific lists
|
||||
using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
|
||||
Transpose, Input, 0);
|
||||
|
||||
const std::vector<MLDataType> dataTypeConstraints = BuildKernelDefConstraintsFunctorFromTypeList<EnabledDataTypes>{}();
|
||||
} // namespace
|
||||
|
||||
/* A permutation [a,b,c,...] indicates that
|
||||
- The 0-th dimension of the output corresponds to the a-th dimension of input
|
||||
- The 1-st dimension of the output corresponds to the b-th dimension of input
|
||||
|
|
@ -152,42 +170,54 @@ inline void CopyPrim(uint8_t* target, const uint8_t* source) {
|
|||
|
||||
// The function does not check num_axes > 0 but this is expected.
|
||||
template <class T>
|
||||
static void TypedDoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
|
||||
static bool TypedDoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
|
||||
const std::vector<size_t>& stride, const uint8_t* source, uint8_t* target) {
|
||||
MultiIndex mindex;
|
||||
IncrementIndexAndComputeOffsetSetup(mindex, num_axes, target_dims, stride, sizeof(T));
|
||||
constexpr bool enabled = utils::HasTypeWithSameSize<EnabledDataTypes, T>();
|
||||
|
||||
const uint8_t* local_source = source;
|
||||
uint8_t* target_end = target + sizeof(T) * num_blocks;
|
||||
for (; target != target_end; target += sizeof(T)) {
|
||||
ORT_ENFORCE((local_source >= source) && (local_source < source + sizeof(T) * num_blocks));
|
||||
CopyPrim<T>(target, local_source);
|
||||
IncrementIndexAndComputeOffset(mindex, local_source);
|
||||
if (enabled) {
|
||||
MultiIndex mindex;
|
||||
IncrementIndexAndComputeOffsetSetup(mindex, num_axes, target_dims, stride, sizeof(T));
|
||||
|
||||
const uint8_t* local_source = source;
|
||||
uint8_t* target_end = target + sizeof(T) * num_blocks;
|
||||
for (; target != target_end; target += sizeof(T)) {
|
||||
ORT_ENFORCE((local_source >= source) && (local_source < source + sizeof(T) * num_blocks));
|
||||
CopyPrim<T>(target, local_source);
|
||||
IncrementIndexAndComputeOffset(mindex, local_source);
|
||||
}
|
||||
}
|
||||
|
||||
return enabled;
|
||||
}
|
||||
|
||||
// DoTransposeEltWise: specialization of DoTranspose for the num_elts_in_block=1 case.
|
||||
// copies source tensor to target, transposing elements.
|
||||
// The stride vector indicates the transposition.
|
||||
void DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
|
||||
const std::vector<size_t>& stride, const uint8_t* source, uint8_t* target,
|
||||
size_t element_size) {
|
||||
Status DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
|
||||
const std::vector<size_t>& stride, const uint8_t* source, uint8_t* target,
|
||||
size_t element_size) {
|
||||
bool enabled = false;
|
||||
switch (element_size) {
|
||||
case sizeof(uint64_t):
|
||||
TypedDoTransposeEltWise<uint64_t>(num_axes, target_dims, num_blocks, stride, source, target);
|
||||
enabled = TypedDoTransposeEltWise<uint64_t>(num_axes, target_dims, num_blocks, stride, source, target);
|
||||
break;
|
||||
case sizeof(uint32_t):
|
||||
TypedDoTransposeEltWise<uint32_t>(num_axes, target_dims, num_blocks, stride, source, target);
|
||||
enabled = TypedDoTransposeEltWise<uint32_t>(num_axes, target_dims, num_blocks, stride, source, target);
|
||||
break;
|
||||
case sizeof(uint16_t):
|
||||
TypedDoTransposeEltWise<uint16_t>(num_axes, target_dims, num_blocks, stride, source, target);
|
||||
enabled = TypedDoTransposeEltWise<uint16_t>(num_axes, target_dims, num_blocks, stride, source, target);
|
||||
break;
|
||||
case sizeof(uint8_t):
|
||||
TypedDoTransposeEltWise<uint8_t>(num_axes, target_dims, num_blocks, stride, source, target);
|
||||
enabled = TypedDoTransposeEltWise<uint8_t>(num_axes, target_dims, num_blocks, stride, source, target);
|
||||
break;
|
||||
default:
|
||||
assert(false);
|
||||
// leave enabled as false
|
||||
break;
|
||||
}
|
||||
|
||||
return enabled ? Status::OK()
|
||||
: ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Transpose of element size not supported in this build. Size=",
|
||||
element_size);
|
||||
}
|
||||
|
||||
static void DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
|
||||
|
|
@ -243,17 +273,25 @@ static Status DoUntypedTranspose(const std::vector<size_t>& permutations, const
|
|||
}
|
||||
}
|
||||
|
||||
Status status = Status::OK();
|
||||
|
||||
if (is_string_type) {
|
||||
const auto* input_data = input.template Data<std::string>();
|
||||
auto* output_data = output.template MutableData<std::string>();
|
||||
if (1 == prefix_blocksize) {
|
||||
DoTransposeSingleBlock(suffix_blocksize, input_data, output_data);
|
||||
} else if (1 == suffix_blocksize) {
|
||||
DoTransposeEltWise(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride,
|
||||
input_data, output_data);
|
||||
constexpr bool string_enabled = utils::HasType<EnabledDataTypes, std::string>();
|
||||
|
||||
if (string_enabled) {
|
||||
const auto* input_data = input.template Data<std::string>();
|
||||
auto* output_data = output.template MutableData<std::string>();
|
||||
if (1 == prefix_blocksize) {
|
||||
DoTransposeSingleBlock(suffix_blocksize, input_data, output_data);
|
||||
} else if (1 == suffix_blocksize) {
|
||||
DoTransposeEltWise(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride,
|
||||
input_data, output_data);
|
||||
} else {
|
||||
DoTransposeImpl(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, suffix_blocksize, stride,
|
||||
input_data, output_data);
|
||||
}
|
||||
} else {
|
||||
DoTransposeImpl(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, suffix_blocksize, stride,
|
||||
input_data, output_data);
|
||||
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Transpose of std::string is not supported in this build.");
|
||||
}
|
||||
} else {
|
||||
const auto* input_data = reinterpret_cast<const uint8_t*>(input.DataRaw());
|
||||
|
|
@ -261,15 +299,16 @@ static Status DoUntypedTranspose(const std::vector<size_t>& permutations, const
|
|||
if (1 == prefix_blocksize) {
|
||||
DoTransposeSingleBlock(suffix_blocksize, input_data, output_data, element_size);
|
||||
} else if (1 == suffix_blocksize) {
|
||||
DoTransposeEltWise(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride,
|
||||
input_data, output_data, element_size);
|
||||
// this may return a failed status if the data size is not supported in this build
|
||||
status = DoTransposeEltWise(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride,
|
||||
input_data, output_data, element_size);
|
||||
} else {
|
||||
DoTransposeImpl(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, suffix_blocksize, stride,
|
||||
input_data, output_data, element_size);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
return status;
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
@ -686,13 +725,13 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
|||
Transpose,
|
||||
1,
|
||||
12,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
|
||||
KernelDefBuilder().TypeConstraint("T", dataTypeConstraints),
|
||||
Transpose);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Transpose,
|
||||
13,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
|
||||
KernelDefBuilder().TypeConstraint("T", dataTypeConstraints),
|
||||
Transpose);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -16,9 +16,10 @@ namespace onnxruntime {
|
|||
*/
|
||||
bool IsTransposeReshape(const std::vector<size_t>& perm, const std::vector<int64_t>& input_dims);
|
||||
|
||||
void DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
|
||||
const std::vector<size_t>& stride, const uint8_t* source, uint8_t* target,
|
||||
size_t element_size);
|
||||
// Public function for element-wise transpose, primarily to unit test any out of bounds access
|
||||
Status DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
|
||||
const std::vector<size_t>& stride, const uint8_t* source, uint8_t* target,
|
||||
size_t element_size);
|
||||
|
||||
class TransposeBase {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -62,14 +62,15 @@ struct GlobalAllowed {};
|
|||
} // namespace tags
|
||||
|
||||
// optionally holds a list of types associated with a tag class
|
||||
// if types are defined, the data member 'types' should contain them in a type list
|
||||
// otherwise, if no types are defined (distinct from an empty list of types), there should be no data member 'types'
|
||||
// if types are defined, the type alias member called 'types' should contain them in a type list
|
||||
// (e.g. using something like std::tuple or a boost::mp11::mp_list)
|
||||
// otherwise, if no types are defined (distinct from an empty list of types), there should be no 'types' type alias
|
||||
// see the tags in onnxruntime::op_kernel_type_control::tags for intended uses
|
||||
template <typename Tag>
|
||||
struct TypesHolder {};
|
||||
|
||||
/**
|
||||
* Provides a type list of enabled types via the 'types' data member.
|
||||
* Provides a type list of enabled types via the 'types' type alias member.
|
||||
* Enabled types are the set intersection of supported and allowed types.
|
||||
*
|
||||
* @tparam SupportedTypesHolder A 'TypesHolder' with a list of supported types.
|
||||
|
|
@ -84,15 +85,15 @@ struct EnabledTypes {
|
|||
template <typename T>
|
||||
using GetTypesMember = typename T::types;
|
||||
|
||||
// checks whether T has data member 'types'
|
||||
// checks whether T has a type alias member called 'types'
|
||||
template <typename T>
|
||||
using HasTypesMember = boost::mp11::mp_valid<GetTypesMember, T>;
|
||||
|
||||
static_assert(HasTypesMember<SupportedTypesHolder>::value,
|
||||
"SupportedTypesHolder must have a 'types' data member.");
|
||||
"SupportedTypesHolder must have a type alias called 'types'.");
|
||||
|
||||
// the allowed type lists to consider
|
||||
// for each element of AllowedTypesHolders, get and include a 'types' data member if present
|
||||
// for each element of AllowedTypesHolders, get and include the 'types' type alias member if present
|
||||
using AllowedTypesMembers =
|
||||
boost::mp11::mp_transform<
|
||||
GetTypesMember,
|
||||
|
|
@ -105,17 +106,10 @@ struct EnabledTypes {
|
|||
boost::mp11::mp_push_front<AllowedTypesMembers, GetTypesMember<SupportedTypesHolder>>;
|
||||
|
||||
static_assert(boost::mp11::mp_all_of<TypeListsToConsider, boost::mp11::mp_is_list>::value,
|
||||
"All 'types' data members must be type lists.");
|
||||
|
||||
// converts type list L into a type set (type list with unique elements)
|
||||
template <typename L>
|
||||
using MakeSet =
|
||||
boost::mp11::mp_apply<
|
||||
boost::mp11::mp_set_push_back,
|
||||
boost::mp11::mp_append<TypeList<TypeList<>>, L>>;
|
||||
"All 'types' type aliases must be type lists.");
|
||||
|
||||
// type lists converted to type sets
|
||||
using TypeSetsToConsider = boost::mp11::mp_transform<MakeSet, TypeListsToConsider>;
|
||||
using TypeSetsToConsider = boost::mp11::mp_transform<boost::mp11::mp_unique, TypeListsToConsider>;
|
||||
|
||||
public:
|
||||
using types = boost::mp11::mp_apply<boost::mp11::mp_set_intersection, TypeSetsToConsider>;
|
||||
|
|
@ -237,37 +231,6 @@ struct EnabledTypes {
|
|||
::onnxruntime::op_kernel_type_control::kAllOpSets, \
|
||||
ArgDirection, ArgIndex)
|
||||
|
||||
/**
|
||||
* std::tuple type with the enabled types for a given Op kernel argument.
|
||||
*
|
||||
* @param OpProvider The Op provider.
|
||||
* @param OpDomain The Op domain.
|
||||
* @param OpName The Op name.
|
||||
* @param OpSet The opset to use for the supported types list.
|
||||
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
|
||||
* @param ArgIndex Index of the given Op kernel argument.
|
||||
*/
|
||||
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE( \
|
||||
OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex) \
|
||||
::boost::mp11::mp_rename< \
|
||||
ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( \
|
||||
OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex, SupportedTypeList), \
|
||||
std::tuple>
|
||||
|
||||
/**
|
||||
* std::tuple type with the enabled types for a given Op kernel argument that is valid for all opsets.
|
||||
*
|
||||
* @param OpProvider The Op provider.
|
||||
* @param OpDomain The Op domain.
|
||||
* @param OpName The Op name.
|
||||
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
|
||||
* @param ArgIndex Index of the given Op kernel argument.
|
||||
*/
|
||||
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE_ALL_OPSETS( \
|
||||
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \
|
||||
ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE(OpProvider, OpDomain, OpName, \
|
||||
::onnxruntime::op_kernel_type_control::kAllOpSets, \
|
||||
ArgDirection, ArgIndex)
|
||||
/**
|
||||
* Usage example:
|
||||
*
|
||||
|
|
@ -277,7 +240,7 @@ struct EnabledTypes {
|
|||
* namespace op_kernel_type_control {
|
||||
* // specify supported types, i.e., the full set of types that can be enabled
|
||||
* ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
|
||||
* MyProvider, DomainContainingMyOp, MyOp, Input, 0,
|
||||
* MyProvider, DomainContainingMyOp, MyOp, OpSet, Input, 0,
|
||||
* int, float, double);
|
||||
* } // namespace op_kernel_type_control
|
||||
* } // namespace onnxruntime
|
||||
|
|
|
|||
50
onnxruntime/core/providers/op_kernel_type_control_utils.h
Normal file
50
onnxruntime/core/providers/op_kernel_type_control_utils.h
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "boost/mp11.hpp"
|
||||
|
||||
#include "core/framework/data_types.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace utils {
|
||||
/**
|
||||
* Check if the set of types contains the specified type.
|
||||
*/
|
||||
template <typename TypeSet, typename T>
|
||||
constexpr bool HasType() {
|
||||
static_assert(boost::mp11::mp_is_set<TypeSet>::value, "TypeSet must be a type set.");
|
||||
|
||||
return boost::mp11::mp_set_contains<TypeSet, T>::value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using SizeOfT = boost::mp11::mp_size_t<sizeof(T)>;
|
||||
|
||||
/**
|
||||
* Check if the set of types contains a type with the same size as T.
|
||||
*
|
||||
* @remarks e.g. will return true if T is int32_t and the list contains any 4 byte type (i.e. sizeof(int32_t))
|
||||
* such as int32_t, uint32_t or float.
|
||||
*/
|
||||
template <typename TypeSet, typename T>
|
||||
constexpr bool HasTypeWithSameSize() {
|
||||
static_assert(boost::mp11::mp_is_set<TypeSet>::value, "TypeSet must be a type set.");
|
||||
|
||||
using EnabledTypeSizes = boost::mp11::mp_unique<boost::mp11::mp_transform<SizeOfT, TypeSet>>;
|
||||
return boost::mp11::mp_set_contains<EnabledTypeSizes, SizeOfT<T>>::value;
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace onnxruntime
|
||||
|
||||
/** Data types that are used in DataTypeImpl::AllTensorTypes()
|
||||
*/
|
||||
#define ORT_OP_KERNEL_TYPE_CTRL_ALL_TENSOR_DATA_TYPES \
|
||||
bool, \
|
||||
float, double, \
|
||||
uint8_t, uint16_t, uint32_t, uint64_t, \
|
||||
int8_t, int16_t, int32_t, int64_t, \
|
||||
MLFloat16, BFloat16, \
|
||||
std::string
|
||||
|
|
@ -5,6 +5,7 @@
|
|||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/providers/compare_provider_test_utils.h"
|
||||
#include "core/providers/cpu/tensor/transpose.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
|
@ -560,9 +561,9 @@ TEST(TransposeOpTest, DoTransposeEltWise) {
|
|||
13.0f, 15.0f, 14.0f, 16.0f,
|
||||
17.0f, 17.0f};
|
||||
|
||||
DoTransposeEltWise(input_shape.size(), input_shape, 16,
|
||||
stride, (uint8_t*)input_vals_end.data(), (uint8_t*)target.data(),
|
||||
sizeof(float));
|
||||
ASSERT_STATUS_OK(DoTransposeEltWise(input_shape.size(), input_shape, 16,
|
||||
stride, (uint8_t*)input_vals_end.data(), (uint8_t*)target.data(),
|
||||
sizeof(float)));
|
||||
for (size_t i = 0; i < input_vals_end.size(); ++i) {
|
||||
ASSERT_TRUE(target[i] == expected_vals3[i]);
|
||||
}
|
||||
|
|
|
|||
2
onnxruntime/test/testdata/mnist.readme.txt
vendored
2
onnxruntime/test/testdata/mnist.readme.txt
vendored
|
|
@ -1,4 +1,4 @@
|
|||
The mnist model is used in a multiple tests for minimal/mobile builds in both ONNX and ORT formats.
|
||||
The mnist model is used in multiple tests for minimal/mobile builds in both ONNX and ORT formats.
|
||||
|
||||
We also save both ONNX and ORT format versions of the model with level 1 (aka 'basic') optimizations applied.
|
||||
- mnist.level1_opt.onnx makes sure the required operators for this model are automatically included in
|
||||
|
|
|
|||
|
|
@ -109,13 +109,18 @@ class DefaultTypeUsageProcessor(TypeUsageProcessor):
|
|||
def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
|
||||
for i in self._input_types.keys():
|
||||
if i >= node.InputsLength():
|
||||
raise RuntimeError('Node has {} inputs. Tracker for {} incorrectly configured as it requires {}.'
|
||||
.format(node.InputsLength(), self.name, i))
|
||||
|
||||
type_str = value_name_to_typestr(node.Inputs(i), value_name_to_typeinfo)
|
||||
self._input_types[i].add(type_str)
|
||||
# Some operators have fewer inputs in earlier versions, where data that was an attribute
|
||||
# became an input in later versions to allow it to be dynamically provided. Allow for that.
|
||||
# e.g. Slice-1 had attributes for the indices, and Slice-10 moved those to be inputs
|
||||
# raise RuntimeError('Node has {} outputs. Tracker for {} incorrectly configured as it requires {}.'
|
||||
# .format(node.OutputsLength(), self.name, o))
|
||||
pass
|
||||
else:
|
||||
type_str = value_name_to_typestr(node.Inputs(i), value_name_to_typeinfo)
|
||||
self._input_types[i].add(type_str)
|
||||
|
||||
for o in self._output_types.keys():
|
||||
# Don't know of any ops where the number of outputs changed across versions, so require a valid length
|
||||
if o >= node.OutputsLength():
|
||||
raise RuntimeError('Node has {} outputs. Tracker for {} incorrectly configured as it requires {}.'
|
||||
.format(node.OutputsLength(), self.name, o))
|
||||
|
|
@ -127,7 +132,7 @@ class DefaultTypeUsageProcessor(TypeUsageProcessor):
|
|||
if 0 not in self._input_types.keys():
|
||||
# currently all standard typed registrations are for input 0.
|
||||
# custom registrations can be handled by operator specific processors (e.g. OneHotProcessor below).
|
||||
raise RuntimeError('Expected typed registration to use type from input 0.')
|
||||
raise RuntimeError('Expected typed registration to use type from input 0. Node:{}'.format(self.name))
|
||||
|
||||
return type_in_registration in self._input_types[0]
|
||||
|
||||
|
|
@ -254,8 +259,8 @@ def _create_operator_type_usage_processors():
|
|||
# - some known large kernels
|
||||
#
|
||||
# Ops we are ignoring currently so as not to produce meaningless/unused output:
|
||||
# - Implementation is not type specific:
|
||||
# If, Loop, Reshape, Scan, Shape, Squeeze, Unsqueeze
|
||||
# - Implementation is type agnostic:
|
||||
# DynamicQuantizeMatMul, If, Loop, Reshape, Scan, Shape, Squeeze, Unsqueeze
|
||||
# - Only one type supported in the ORT implementation:
|
||||
# FusedConv, FusedGemm, FusedMatMul, TransposeMatMul
|
||||
# - Implementation does not have any significant type specific code:
|
||||
|
|
@ -264,7 +269,7 @@ def _create_operator_type_usage_processors():
|
|||
'DequantizeLinear', 'Div', 'Equal', 'Exp', 'Expand',
|
||||
'Gemm', 'Greater', 'Less', 'MatMul', 'Max', 'Min', 'Mul',
|
||||
'NonMaxSuppression', 'NonZero', 'Pad', 'Range', 'Relu', 'Resize',
|
||||
'Sigmoid', 'Slice', 'Softmax', 'Split', 'Sub', 'Tile', 'TopK', 'Transpose']
|
||||
'Sigmoid', 'Softmax', 'Split', 'Sub', 'Tile', 'TopK', 'Transpose']
|
||||
|
||||
internal_ops = ['QLinearAdd', 'QLinearMul']
|
||||
|
||||
|
|
@ -286,12 +291,15 @@ def _create_operator_type_usage_processors():
|
|||
#
|
||||
# Operators that require custom handling
|
||||
#
|
||||
add(DefaultTypeUsageProcessor('ai.onnx', 'Cast', inputs=[0], outputs=[0])) # track input0 and output0
|
||||
|
||||
# Cast switches on types of input 0 and output 0
|
||||
add(DefaultTypeUsageProcessor('ai.onnx', 'Cast', inputs=[0], outputs=[0]))
|
||||
|
||||
# Operators that switch on the type of input 0 and 1
|
||||
add(DefaultTypeUsageProcessor('ai.onnx', 'Gather', inputs=[0, 1]))
|
||||
add(DefaultTypeUsageProcessor('ai.onnx', 'GatherElements', inputs=[0, 1]))
|
||||
add(DefaultTypeUsageProcessor('ai.onnx', 'Pow', inputs=[0, 1]))
|
||||
add(DefaultTypeUsageProcessor('ai.onnx', 'Slice', inputs=[0, 1]))
|
||||
|
||||
# Operators that switch on output type
|
||||
add(DefaultTypeUsageProcessor('ai.onnx', 'ConstantOfShape', inputs=[], outputs=[0]))
|
||||
|
|
|
|||
Loading…
Reference in a new issue