Add type reduction support to Slice and Transpose (#6547)

* Add type reduction support to Slice and Transpose
This commit is contained in:
Scott McKay 2021-02-05 11:08:23 +10:00 committed by GitHub
parent 89627a8178
commit c49d1dbc4b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 286 additions and 165 deletions

View file

@ -116,7 +116,7 @@ constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<BFloat16
function<int8_t>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
function<uint8_t>(__VA_ARGS__); \
function<uint8_t>(__VA_ARGS__); \
break; \
case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
function<int16_t>(__VA_ARGS__); \

View file

@ -4,6 +4,8 @@
#include "core/providers/cpu/tensor/slice.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/common.h"
#include "core/providers/op_kernel_type_control.h"
#include "core/providers/op_kernel_type_control_utils.h"
#include <unordered_map>
#include <limits>
@ -11,40 +13,28 @@ using namespace ::onnxruntime::common;
using namespace std;
namespace onnxruntime {
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Slice,
1, 9,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
Slice1);
namespace op_kernel_type_control {
// we're using one set of types for all opsets
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Slice, Input, 0,
ORT_OP_KERNEL_TYPE_CTRL_ALL_TENSOR_DATA_TYPES);
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Slice,
10, 10,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
Slice10);
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Slice, Input, 1, int32_t, int64_t);
} // namespace op_kernel_type_control
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Slice,
11,
12,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
Slice10);
ONNX_CPU_OPERATOR_KERNEL(
Slice,
13,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
Slice10);
namespace {
using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
Slice, Input, 0);
using EnabledIndicesTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
Slice, Input, 1);
const std::vector<MLDataType> dataTypeConstraints =
BuildKernelDefConstraintsFunctorFromTypeList<EnabledDataTypes>{}();
const std::vector<MLDataType> indicesTypeConstraints =
BuildKernelDefConstraintsFunctorFromTypeList<EnabledIndicesTypes>{}();
// std::clamp doesn't exist until C++17 so create a local version
template <typename T>
const T& clamp(const T& v, const T& lo, const T& hi) {
@ -54,6 +44,37 @@ const T& clamp(const T& v, const T& lo, const T& hi) {
}
} // namespace
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Slice,
1, 9,
KernelDefBuilder().TypeConstraint("T", dataTypeConstraints),
Slice1);
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Slice,
10, 10,
KernelDefBuilder()
.TypeConstraint("T", dataTypeConstraints)
.TypeConstraint("Tind", indicesTypeConstraints),
Slice10);
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Slice,
11,
12,
KernelDefBuilder()
.TypeConstraint("T", dataTypeConstraints)
.TypeConstraint("Tind", indicesTypeConstraints),
Slice10);
ONNX_CPU_OPERATOR_KERNEL(
Slice,
13,
KernelDefBuilder()
.TypeConstraint("T", dataTypeConstraints)
.TypeConstraint("Tind", indicesTypeConstraints),
Slice10);
// Check if it's possible to combine innermost dimensions so we copy larger blocks.
// Sets flattened_output_dims to nullptr if it is not.
// Updates starts and steps to match flattened_output_dims if it is.
@ -218,19 +239,21 @@ Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
}
// Slice V10 & DynamicSlice
void SliceBase::FillVectorsFromInput(const Tensor& start_tensor,
const Tensor& ends_tensor,
const Tensor* axes_tensor,
const Tensor* steps_tensor,
std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends,
std::vector<int64_t>& input_axes,
std::vector<int64_t>& input_steps) {
ORT_ENFORCE(start_tensor.Shape().NumDimensions() == 1, "Starts must be a 1-D array");
ORT_ENFORCE(ends_tensor.Shape().NumDimensions() == 1, "Ends must be a 1-D array");
ORT_ENFORCE(start_tensor.Shape() == ends_tensor.Shape(), "Starts and ends shape mismatch");
ORT_ENFORCE(nullptr == axes_tensor || start_tensor.Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch");
ORT_ENFORCE(nullptr == steps_tensor || start_tensor.Shape() == steps_tensor->Shape(), "Starts and steps shape mismatch");
Status SliceBase::FillVectorsFromInput(const Tensor& start_tensor,
const Tensor& ends_tensor,
const Tensor* axes_tensor,
const Tensor* steps_tensor,
std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends,
std::vector<int64_t>& input_axes,
std::vector<int64_t>& input_steps) {
ORT_RETURN_IF_NOT(start_tensor.Shape().NumDimensions() == 1, "Starts must be a 1-D array");
ORT_RETURN_IF_NOT(ends_tensor.Shape().NumDimensions() == 1, "Ends must be a 1-D array");
ORT_RETURN_IF_NOT(start_tensor.Shape() == ends_tensor.Shape(), "Starts and ends shape mismatch");
ORT_RETURN_IF_NOT(nullptr == axes_tensor || start_tensor.Shape() == axes_tensor->Shape(),
"Starts and axes shape mismatch");
ORT_RETURN_IF_NOT(nullptr == steps_tensor || start_tensor.Shape() == steps_tensor->Shape(),
"Starts and steps shape mismatch");
const auto& size = start_tensor.Shape().Size();
input_starts.resize(size);
@ -241,7 +264,11 @@ void SliceBase::FillVectorsFromInput(const Tensor& start_tensor,
if (nullptr != steps_tensor)
input_steps.resize(size);
if (start_tensor.IsDataType<int32_t>()) {
// check for type reduction of supported indices types
constexpr bool int32_enabled = utils::HasType<EnabledIndicesTypes, int32_t>();
constexpr bool int64_enabled = utils::HasType<EnabledIndicesTypes, int64_t>();
if (int32_enabled && start_tensor.IsDataType<int32_t>()) {
std::copy(start_tensor.Data<int32_t>(), start_tensor.Data<int32_t>() + size, input_starts.begin());
std::copy(ends_tensor.Data<int32_t>(), ends_tensor.Data<int32_t>() + size, input_ends.begin());
if (nullptr != axes_tensor)
@ -251,7 +278,7 @@ void SliceBase::FillVectorsFromInput(const Tensor& start_tensor,
std::copy(steps_tensor->Data<int32_t>(), steps_tensor->Data<int32_t>() + size, input_steps.begin());
}
else if (start_tensor.IsDataType<int64_t>()) {
else if (int64_enabled && start_tensor.IsDataType<int64_t>()) {
std::copy(start_tensor.Data<int64_t>(), start_tensor.Data<int64_t>() + size, input_starts.begin());
std::copy(ends_tensor.Data<int64_t>(), ends_tensor.Data<int64_t>() + size, input_ends.begin());
if (nullptr != axes_tensor)
@ -261,10 +288,13 @@ void SliceBase::FillVectorsFromInput(const Tensor& start_tensor,
std::copy(steps_tensor->Data<int64_t>(), steps_tensor->Data<int64_t>() + size, input_steps.begin());
}
// should not reach this as no kernel is registered for this condition to be triggered - just an additional safety check
else {
ORT_THROW("Data type for starts and ends inputs' need to be int32_t or int64_t, but instead got ", start_tensor.DataType());
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Data type for starts and ends inputs' is not supported in this build. Got ",
start_tensor.DataType());
}
return Status::OK();
}
template <typename T>
@ -304,22 +334,40 @@ static Status SliceImpl(OpKernelContext* ctx,
flattened_input_dims.back() = compute_metadata.p_flattened_output_dims_->back();
TensorShape input_shape(std::move(flattened_input_dims));
auto input_iterator = SliceIterator<T>(input_tensor, input_shape, compute_metadata.starts_, *compute_metadata.p_flattened_output_dims_, compute_metadata.steps_);
auto input_iterator = SliceIterator<T>(input_tensor, input_shape, compute_metadata.starts_,
*compute_metadata.p_flattened_output_dims_, compute_metadata.steps_);
create_output(input_iterator);
} else {
auto input_iterator = SliceIterator<T>(input_tensor, compute_metadata.starts_, compute_metadata.output_dims_, compute_metadata.steps_);
auto input_iterator = SliceIterator<T>(input_tensor, compute_metadata.starts_, compute_metadata.output_dims_,
compute_metadata.steps_);
create_output(input_iterator);
}
return Status::OK();
}
template <typename EnabledTypes, typename T>
static inline bool CallSliceImplIfEnabled(OpKernelContext* ctx,
const Tensor& input_tensor,
SliceOp::PrepareForComputeMetadata& compute_metadata,
Status& status) {
constexpr bool enabled = utils::HasTypeWithSameSize<EnabledTypes, T>();
if (enabled) {
status = SliceImpl<T>(ctx, input_tensor, compute_metadata);
}
return enabled;
}
Status SliceBase::Compute(OpKernelContext* ctx) const {
const auto* input_tensor_ptr = ctx->Input<Tensor>(0);
ORT_ENFORCE(input_tensor_ptr != nullptr, "Missing input tensor to be processed");
const auto& input_tensor = *input_tensor_ptr;
const auto& input_dimensions = input_tensor.Shape().GetDims();
if (input_dimensions.empty()) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars");
if (input_dimensions.empty()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars");
}
SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions);
// Slice V10 & DynamicSlice
@ -328,8 +376,9 @@ Status SliceBase::Compute(OpKernelContext* ctx) const {
std::vector<int64_t> input_ends;
std::vector<int64_t> input_axes;
std::vector<int64_t> input_steps;
FillVectorsFromInput(*ctx->Input<Tensor>(1), *ctx->Input<Tensor>(2), ctx->Input<Tensor>(3),
ctx->Input<Tensor>(4), input_starts, input_ends, input_axes, input_steps);
ORT_RETURN_IF_ERROR(FillVectorsFromInput(*ctx->Input<Tensor>(1), *ctx->Input<Tensor>(2),
ctx->Input<Tensor>(3), ctx->Input<Tensor>(4),
input_starts, input_ends, input_axes, input_steps));
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata));
}
@ -340,28 +389,38 @@ Status SliceBase::Compute(OpKernelContext* ctx) const {
Status status = Status::OK();
bool supported = false;
if (input_tensor.IsDataTypeString()) {
status = SliceImpl<std::string>(ctx, input_tensor, compute_metadata);
if (utils::HasType<EnabledDataTypes, std::string>()) {
supported = true;
status = SliceImpl<std::string>(ctx, input_tensor, compute_metadata);
}
} else {
const auto element_size = input_tensor.DataType()->Size();
// call SliceImpl
switch (element_size) {
case sizeof(uint32_t):
status = SliceImpl<uint32_t>(ctx, input_tensor, compute_metadata);
supported = CallSliceImplIfEnabled<EnabledDataTypes, uint32_t>(ctx, input_tensor, compute_metadata, status);
break;
case sizeof(uint64_t):
status = SliceImpl<uint64_t>(ctx, input_tensor, compute_metadata);
supported = CallSliceImplIfEnabled<EnabledDataTypes, uint64_t>(ctx, input_tensor, compute_metadata, status);
break;
case sizeof(uint16_t):
status = SliceImpl<uint16_t>(ctx, input_tensor, compute_metadata);
supported = CallSliceImplIfEnabled<EnabledDataTypes, uint16_t>(ctx, input_tensor, compute_metadata, status);
break;
case sizeof(uint8_t):
status = SliceImpl<uint8_t>(ctx, input_tensor, compute_metadata);
supported = CallSliceImplIfEnabled<EnabledDataTypes, uint8_t>(ctx, input_tensor, compute_metadata, status);
break;
default:
ORT_THROW("Unsupported input data type of ", input_tensor.DataType());
// leave 'supported' as false
break;
}
}
if (!supported) {
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input data type of ", input_tensor.DataType());
}
return status;
}

View file

@ -44,14 +44,14 @@ class SliceBase {
SliceOp::PrepareForComputeMetadata& compute_metadata);
// Slice V10 & DynamicSlice
static void FillVectorsFromInput(const Tensor& start_tensor,
const Tensor& ends_tensor,
const Tensor* axes_tensor,
const Tensor* steps_tensor,
std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends,
std::vector<int64_t>& input_axes,
std::vector<int64_t>& input_steps);
static Status FillVectorsFromInput(const Tensor& start_tensor,
const Tensor& ends_tensor,
const Tensor* axes_tensor,
const Tensor* steps_tensor,
std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends,
std::vector<int64_t>& input_axes,
std::vector<int64_t>& input_steps);
protected:
SliceBase(const OpKernelInfo& info, bool dynamic = false)

View file

@ -2,12 +2,30 @@
// Licensed under the MIT License.
#include "core/providers/cpu/tensor/transpose.h"
#include "core/framework/utils.h"
#include "core/mlas/inc/mlas.h"
#include "core/providers/op_kernel_type_control.h"
#include "core/providers/op_kernel_type_control_utils.h"
#include "utils.h"
namespace onnxruntime {
namespace op_kernel_type_control {
// we're using one set of types for all opsets
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Transpose, Input, 0,
ORT_OP_KERNEL_TYPE_CTRL_ALL_TENSOR_DATA_TYPES);
} // namespace op_kernel_type_control
namespace {
// reduce the supported types with any global or op specific lists
using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
Transpose, Input, 0);
const std::vector<MLDataType> dataTypeConstraints = BuildKernelDefConstraintsFunctorFromTypeList<EnabledDataTypes>{}();
} // namespace
/* A permutation [a,b,c,...] indicates that
- The 0-th dimension of the output corresponds to the a-th dimension of input
- The 1-st dimension of the output corresponds to the b-th dimension of input
@ -152,42 +170,54 @@ inline void CopyPrim(uint8_t* target, const uint8_t* source) {
// The function does not check num_axes > 0 but this is expected.
template <class T>
static void TypedDoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
static bool TypedDoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
const std::vector<size_t>& stride, const uint8_t* source, uint8_t* target) {
MultiIndex mindex;
IncrementIndexAndComputeOffsetSetup(mindex, num_axes, target_dims, stride, sizeof(T));
constexpr bool enabled = utils::HasTypeWithSameSize<EnabledDataTypes, T>();
const uint8_t* local_source = source;
uint8_t* target_end = target + sizeof(T) * num_blocks;
for (; target != target_end; target += sizeof(T)) {
ORT_ENFORCE((local_source >= source) && (local_source < source + sizeof(T) * num_blocks));
CopyPrim<T>(target, local_source);
IncrementIndexAndComputeOffset(mindex, local_source);
if (enabled) {
MultiIndex mindex;
IncrementIndexAndComputeOffsetSetup(mindex, num_axes, target_dims, stride, sizeof(T));
const uint8_t* local_source = source;
uint8_t* target_end = target + sizeof(T) * num_blocks;
for (; target != target_end; target += sizeof(T)) {
ORT_ENFORCE((local_source >= source) && (local_source < source + sizeof(T) * num_blocks));
CopyPrim<T>(target, local_source);
IncrementIndexAndComputeOffset(mindex, local_source);
}
}
return enabled;
}
// DoTransposeEltWise: specialization of DoTranspose for the num_elts_in_block=1 case.
// copies source tensor to target, transposing elements.
// The stride vector indicates the transposition.
void DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
const std::vector<size_t>& stride, const uint8_t* source, uint8_t* target,
size_t element_size) {
Status DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
const std::vector<size_t>& stride, const uint8_t* source, uint8_t* target,
size_t element_size) {
bool enabled = false;
switch (element_size) {
case sizeof(uint64_t):
TypedDoTransposeEltWise<uint64_t>(num_axes, target_dims, num_blocks, stride, source, target);
enabled = TypedDoTransposeEltWise<uint64_t>(num_axes, target_dims, num_blocks, stride, source, target);
break;
case sizeof(uint32_t):
TypedDoTransposeEltWise<uint32_t>(num_axes, target_dims, num_blocks, stride, source, target);
enabled = TypedDoTransposeEltWise<uint32_t>(num_axes, target_dims, num_blocks, stride, source, target);
break;
case sizeof(uint16_t):
TypedDoTransposeEltWise<uint16_t>(num_axes, target_dims, num_blocks, stride, source, target);
enabled = TypedDoTransposeEltWise<uint16_t>(num_axes, target_dims, num_blocks, stride, source, target);
break;
case sizeof(uint8_t):
TypedDoTransposeEltWise<uint8_t>(num_axes, target_dims, num_blocks, stride, source, target);
enabled = TypedDoTransposeEltWise<uint8_t>(num_axes, target_dims, num_blocks, stride, source, target);
break;
default:
assert(false);
// leave enabled as false
break;
}
return enabled ? Status::OK()
: ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Transpose of element size not supported in this build. Size=",
element_size);
}
static void DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
@ -243,17 +273,25 @@ static Status DoUntypedTranspose(const std::vector<size_t>& permutations, const
}
}
Status status = Status::OK();
if (is_string_type) {
const auto* input_data = input.template Data<std::string>();
auto* output_data = output.template MutableData<std::string>();
if (1 == prefix_blocksize) {
DoTransposeSingleBlock(suffix_blocksize, input_data, output_data);
} else if (1 == suffix_blocksize) {
DoTransposeEltWise(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride,
input_data, output_data);
constexpr bool string_enabled = utils::HasType<EnabledDataTypes, std::string>();
if (string_enabled) {
const auto* input_data = input.template Data<std::string>();
auto* output_data = output.template MutableData<std::string>();
if (1 == prefix_blocksize) {
DoTransposeSingleBlock(suffix_blocksize, input_data, output_data);
} else if (1 == suffix_blocksize) {
DoTransposeEltWise(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride,
input_data, output_data);
} else {
DoTransposeImpl(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, suffix_blocksize, stride,
input_data, output_data);
}
} else {
DoTransposeImpl(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, suffix_blocksize, stride,
input_data, output_data);
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Transpose of std::string is not supported in this build.");
}
} else {
const auto* input_data = reinterpret_cast<const uint8_t*>(input.DataRaw());
@ -261,15 +299,16 @@ static Status DoUntypedTranspose(const std::vector<size_t>& permutations, const
if (1 == prefix_blocksize) {
DoTransposeSingleBlock(suffix_blocksize, input_data, output_data, element_size);
} else if (1 == suffix_blocksize) {
DoTransposeEltWise(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride,
input_data, output_data, element_size);
// this may return a failed status if the data size is not supported in this build
status = DoTransposeEltWise(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride,
input_data, output_data, element_size);
} else {
DoTransposeImpl(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, suffix_blocksize, stride,
input_data, output_data, element_size);
}
}
return Status::OK();
return status;
}
/*
@ -686,13 +725,13 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Transpose,
1,
12,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
KernelDefBuilder().TypeConstraint("T", dataTypeConstraints),
Transpose);
ONNX_CPU_OPERATOR_KERNEL(
Transpose,
13,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
KernelDefBuilder().TypeConstraint("T", dataTypeConstraints),
Transpose);
} // namespace onnxruntime

View file

@ -16,9 +16,10 @@ namespace onnxruntime {
*/
bool IsTransposeReshape(const std::vector<size_t>& perm, const std::vector<int64_t>& input_dims);
void DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
const std::vector<size_t>& stride, const uint8_t* source, uint8_t* target,
size_t element_size);
// Public function for element-wise transpose, primarily to unit test any out of bounds access
Status DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
const std::vector<size_t>& stride, const uint8_t* source, uint8_t* target,
size_t element_size);
class TransposeBase {
public:

View file

@ -62,14 +62,15 @@ struct GlobalAllowed {};
} // namespace tags
// optionally holds a list of types associated with a tag class
// if types are defined, the data member 'types' should contain them in a type list
// otherwise, if no types are defined (distinct from an empty list of types), there should be no data member 'types'
// if types are defined, the type alias member called 'types' should contain them in a type list
// (e.g. using something like std::tuple or a boost::mp11::mp_list)
// otherwise, if no types are defined (distinct from an empty list of types), there should be no 'types' type alias
// see the tags in onnxruntime::op_kernel_type_control::tags for intended uses
template <typename Tag>
struct TypesHolder {};
/**
* Provides a type list of enabled types via the 'types' data member.
* Provides a type list of enabled types via the 'types' type alias member.
* Enabled types are the set intersection of supported and allowed types.
*
* @tparam SupportedTypesHolder A 'TypesHolder' with a list of supported types.
@ -84,15 +85,15 @@ struct EnabledTypes {
template <typename T>
using GetTypesMember = typename T::types;
// checks whether T has data member 'types'
// checks whether T has a type alias member called 'types'
template <typename T>
using HasTypesMember = boost::mp11::mp_valid<GetTypesMember, T>;
static_assert(HasTypesMember<SupportedTypesHolder>::value,
"SupportedTypesHolder must have a 'types' data member.");
"SupportedTypesHolder must have a type alias called 'types'.");
// the allowed type lists to consider
// for each element of AllowedTypesHolders, get and include a 'types' data member if present
// for each element of AllowedTypesHolders, get and include the 'types' type alias member if present
using AllowedTypesMembers =
boost::mp11::mp_transform<
GetTypesMember,
@ -105,17 +106,10 @@ struct EnabledTypes {
boost::mp11::mp_push_front<AllowedTypesMembers, GetTypesMember<SupportedTypesHolder>>;
static_assert(boost::mp11::mp_all_of<TypeListsToConsider, boost::mp11::mp_is_list>::value,
"All 'types' data members must be type lists.");
// converts type list L into a type set (type list with unique elements)
template <typename L>
using MakeSet =
boost::mp11::mp_apply<
boost::mp11::mp_set_push_back,
boost::mp11::mp_append<TypeList<TypeList<>>, L>>;
"All 'types' type aliases must be type lists.");
// type lists converted to type sets
using TypeSetsToConsider = boost::mp11::mp_transform<MakeSet, TypeListsToConsider>;
using TypeSetsToConsider = boost::mp11::mp_transform<boost::mp11::mp_unique, TypeListsToConsider>;
public:
using types = boost::mp11::mp_apply<boost::mp11::mp_set_intersection, TypeSetsToConsider>;
@ -237,37 +231,6 @@ struct EnabledTypes {
::onnxruntime::op_kernel_type_control::kAllOpSets, \
ArgDirection, ArgIndex)
/**
* std::tuple type with the enabled types for a given Op kernel argument.
*
* @param OpProvider The Op provider.
* @param OpDomain The Op domain.
* @param OpName The Op name.
* @param OpSet The opset to use for the supported types list.
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
* @param ArgIndex Index of the given Op kernel argument.
*/
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE( \
OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex) \
::boost::mp11::mp_rename< \
ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( \
OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex, SupportedTypeList), \
std::tuple>
/**
* std::tuple type with the enabled types for a given Op kernel argument that is valid for all opsets.
*
* @param OpProvider The Op provider.
* @param OpDomain The Op domain.
* @param OpName The Op name.
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
* @param ArgIndex Index of the given Op kernel argument.
*/
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE_ALL_OPSETS( \
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \
ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE(OpProvider, OpDomain, OpName, \
::onnxruntime::op_kernel_type_control::kAllOpSets, \
ArgDirection, ArgIndex)
/**
* Usage example:
*
@ -277,7 +240,7 @@ struct EnabledTypes {
* namespace op_kernel_type_control {
* // specify supported types, i.e., the full set of types that can be enabled
* ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
* MyProvider, DomainContainingMyOp, MyOp, Input, 0,
* MyProvider, DomainContainingMyOp, MyOp, OpSet, Input, 0,
* int, float, double);
* } // namespace op_kernel_type_control
* } // namespace onnxruntime

View file

@ -0,0 +1,50 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "boost/mp11.hpp"
#include "core/framework/data_types.h"
namespace onnxruntime {
namespace utils {
/**
* Check if the set of types contains the specified type.
*/
template <typename TypeSet, typename T>
constexpr bool HasType() {
static_assert(boost::mp11::mp_is_set<TypeSet>::value, "TypeSet must be a type set.");
return boost::mp11::mp_set_contains<TypeSet, T>::value;
}
template <typename T>
using SizeOfT = boost::mp11::mp_size_t<sizeof(T)>;
/**
* Check if the set of types contains a type with the same size as T.
*
* @remarks e.g. will return true if T is int32_t and the list contains any 4 byte type (i.e. sizeof(int32_t))
* such as int32_t, uint32_t or float.
*/
template <typename TypeSet, typename T>
constexpr bool HasTypeWithSameSize() {
static_assert(boost::mp11::mp_is_set<TypeSet>::value, "TypeSet must be a type set.");
using EnabledTypeSizes = boost::mp11::mp_unique<boost::mp11::mp_transform<SizeOfT, TypeSet>>;
return boost::mp11::mp_set_contains<EnabledTypeSizes, SizeOfT<T>>::value;
}
} // namespace utils
} // namespace onnxruntime
/** Data types that are used in DataTypeImpl::AllTensorTypes()
*/
#define ORT_OP_KERNEL_TYPE_CTRL_ALL_TENSOR_DATA_TYPES \
bool, \
float, double, \
uint8_t, uint16_t, uint32_t, uint64_t, \
int8_t, int16_t, int32_t, int64_t, \
MLFloat16, BFloat16, \
std::string

View file

@ -5,6 +5,7 @@
#include "test/providers/provider_test_utils.h"
#include "test/providers/compare_provider_test_utils.h"
#include "core/providers/cpu/tensor/transpose.h"
#include "test/util/include/asserts.h"
namespace onnxruntime {
namespace test {
@ -560,9 +561,9 @@ TEST(TransposeOpTest, DoTransposeEltWise) {
13.0f, 15.0f, 14.0f, 16.0f,
17.0f, 17.0f};
DoTransposeEltWise(input_shape.size(), input_shape, 16,
stride, (uint8_t*)input_vals_end.data(), (uint8_t*)target.data(),
sizeof(float));
ASSERT_STATUS_OK(DoTransposeEltWise(input_shape.size(), input_shape, 16,
stride, (uint8_t*)input_vals_end.data(), (uint8_t*)target.data(),
sizeof(float)));
for (size_t i = 0; i < input_vals_end.size(); ++i) {
ASSERT_TRUE(target[i] == expected_vals3[i]);
}

View file

@ -1,4 +1,4 @@
The mnist model is used in a multiple tests for minimal/mobile builds in both ONNX and ORT formats.
The mnist model is used in multiple tests for minimal/mobile builds in both ONNX and ORT formats.
We also save both ONNX and ORT format versions of the model with level 1 (aka 'basic') optimizations applied.
- mnist.level1_opt.onnx makes sure the required operators for this model are automatically included in

View file

@ -109,13 +109,18 @@ class DefaultTypeUsageProcessor(TypeUsageProcessor):
def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
for i in self._input_types.keys():
if i >= node.InputsLength():
raise RuntimeError('Node has {} inputs. Tracker for {} incorrectly configured as it requires {}.'
.format(node.InputsLength(), self.name, i))
type_str = value_name_to_typestr(node.Inputs(i), value_name_to_typeinfo)
self._input_types[i].add(type_str)
# Some operators have fewer inputs in earlier versions where data that was as an attribute
# become an input in later versions to allow it to be dynamically provided. Allow for that.
# e.g. Slice-1 had attributes for the indices, and Slice-10 moved those to be inputs
# raise RuntimeError('Node has {} outputs. Tracker for {} incorrectly configured as it requires {}.'
# .format(node.OutputsLength(), self.name, o))
pass
else:
type_str = value_name_to_typestr(node.Inputs(i), value_name_to_typeinfo)
self._input_types[i].add(type_str)
for o in self._output_types.keys():
# Don't know of any ops where the number of outputs changed across versions, so require a valid length
if o >= node.OutputsLength():
raise RuntimeError('Node has {} outputs. Tracker for {} incorrectly configured as it requires {}.'
.format(node.OutputsLength(), self.name, o))
@ -127,7 +132,7 @@ class DefaultTypeUsageProcessor(TypeUsageProcessor):
if 0 not in self._input_types.keys():
# currently all standard typed registrations are for input 0.
# custom registrations can be handled by operator specific processors (e.g. OneHotProcessor below).
raise RuntimeError('Expected typed registration to use type from input 0.')
raise RuntimeError('Expected typed registration to use type from input 0. Node:{}'.format(self.name))
return type_in_registration in self._input_types[0]
@ -254,8 +259,8 @@ def _create_operator_type_usage_processors():
# - some known large kernels
#
# Ops we are ignoring currently so as not to produce meaningless/unused output:
# - Implementation is not type specific:
# If, Loop, Reshape, Scan, Shape, Squeeze, Unsqueeze
# - Implementation is type agnostic:
# DynamicQuantizeMatMul, If, Loop, Reshape, Scan, Shape, Squeeze, Unsqueeze
# - Only one type supported in the ORT implementation:
# FusedConv, FusedGemm, FusedMatMul, TransposeMatMul
# - Implementation does not have any significant type specific code:
@ -264,7 +269,7 @@ def _create_operator_type_usage_processors():
'DequantizeLinear', 'Div', 'Equal', 'Exp', 'Expand',
'Gemm', 'Greater', 'Less', 'MatMul', 'Max', 'Min', 'Mul',
'NonMaxSuppression', 'NonZero', 'Pad', 'Range', 'Relu', 'Resize',
'Sigmoid', 'Slice', 'Softmax', 'Split', 'Sub', 'Tile', 'TopK', 'Transpose']
'Sigmoid', 'Softmax', 'Split', 'Sub', 'Tile', 'TopK', 'Transpose']
internal_ops = ['QLinearAdd', 'QLinearMul']
@ -286,12 +291,15 @@ def _create_operator_type_usage_processors():
#
# Operators that require custom handling
#
add(DefaultTypeUsageProcessor('ai.onnx', 'Cast', inputs=[0], outputs=[0])) # track input0 and output0
# Cast switches on types of input 0 and output 0
add(DefaultTypeUsageProcessor('ai.onnx', 'Cast', inputs=[0], outputs=[0]))
# Operators that switch on the type of input 0 and 1
add(DefaultTypeUsageProcessor('ai.onnx', 'Gather', inputs=[0, 1]))
add(DefaultTypeUsageProcessor('ai.onnx', 'GatherElements', inputs=[0, 1]))
add(DefaultTypeUsageProcessor('ai.onnx', 'Pow', inputs=[0, 1]))
add(DefaultTypeUsageProcessor('ai.onnx', 'Slice', inputs=[0, 1]))
# Operators that switch on output type
add(DefaultTypeUsageProcessor('ai.onnx', 'ConstantOfShape', inputs=[], outputs=[0]))