diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index c195deaee0..7dfd5da467 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -116,7 +116,7 @@ constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType(__VA_ARGS__); \ break; \ case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \ - function(__VA_ARGS__); \ + function(__VA_ARGS__); \ break; \ case ONNX_NAMESPACE::TensorProto_DataType_INT16: \ function(__VA_ARGS__); \ diff --git a/onnxruntime/core/providers/cpu/tensor/slice.cc b/onnxruntime/core/providers/cpu/tensor/slice.cc index a95d6a5510..52784823f5 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.cc +++ b/onnxruntime/core/providers/cpu/tensor/slice.cc @@ -4,6 +4,8 @@ #include "core/providers/cpu/tensor/slice.h" #include "core/providers/cpu/tensor/utils.h" #include "core/providers/common.h" +#include "core/providers/op_kernel_type_control.h" +#include "core/providers/op_kernel_type_control_utils.h" #include #include @@ -11,40 +13,28 @@ using namespace ::onnxruntime::common; using namespace std; namespace onnxruntime { -ONNX_CPU_OPERATOR_VERSIONED_KERNEL( - Slice, - 1, 9, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()), - Slice1); +namespace op_kernel_type_control { +// we're using one set of types for all opsets +ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS( + kCpuExecutionProvider, kOnnxDomain, Slice, Input, 0, + ORT_OP_KERNEL_TYPE_CTRL_ALL_TENSOR_DATA_TYPES); -ONNX_CPU_OPERATOR_VERSIONED_KERNEL( - Slice, - 10, 10, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) - .TypeConstraint("Tind", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), - Slice10); +ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS( + kCpuExecutionProvider, kOnnxDomain, Slice, Input, 1, int32_t, int64_t); +} // namespace 
op_kernel_type_control -ONNX_CPU_OPERATOR_VERSIONED_KERNEL( - Slice, - 11, - 12, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) - .TypeConstraint("Tind", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), - Slice10); - -ONNX_CPU_OPERATOR_KERNEL( - Slice, - 13, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) - .TypeConstraint("Tind", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), - Slice10); namespace { +using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, + Slice, Input, 0); +using EnabledIndicesTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, + Slice, Input, 1); + +const std::vector dataTypeConstraints = + BuildKernelDefConstraintsFunctorFromTypeList{}(); + +const std::vector indicesTypeConstraints = + BuildKernelDefConstraintsFunctorFromTypeList{}(); + // std::clamp doesn't exist until C++17 so create a local version template const T& clamp(const T& v, const T& lo, const T& hi) { @@ -54,6 +44,37 @@ const T& clamp(const T& v, const T& lo, const T& hi) { } } // namespace +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( + Slice, + 1, 9, + KernelDefBuilder().TypeConstraint("T", dataTypeConstraints), + Slice1); + +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( + Slice, + 10, 10, + KernelDefBuilder() + .TypeConstraint("T", dataTypeConstraints) + .TypeConstraint("Tind", indicesTypeConstraints), + Slice10); + +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( + Slice, + 11, + 12, + KernelDefBuilder() + .TypeConstraint("T", dataTypeConstraints) + .TypeConstraint("Tind", indicesTypeConstraints), + Slice10); + +ONNX_CPU_OPERATOR_KERNEL( + Slice, + 13, + KernelDefBuilder() + .TypeConstraint("T", dataTypeConstraints) + .TypeConstraint("Tind", indicesTypeConstraints), + Slice10); + // Check if it's possible to combine innermost dimensions so we copy larger blocks. // Sets flattened_output_dims to nullptr if it is not. 
// Updates starts and steps to match flattened_output_dims if it is. @@ -218,19 +239,21 @@ Status SliceBase::PrepareForCompute(const std::vector& raw_starts, } // Slice V10 & DynamicSlice -void SliceBase::FillVectorsFromInput(const Tensor& start_tensor, - const Tensor& ends_tensor, - const Tensor* axes_tensor, - const Tensor* steps_tensor, - std::vector& input_starts, - std::vector& input_ends, - std::vector& input_axes, - std::vector& input_steps) { - ORT_ENFORCE(start_tensor.Shape().NumDimensions() == 1, "Starts must be a 1-D array"); - ORT_ENFORCE(ends_tensor.Shape().NumDimensions() == 1, "Ends must be a 1-D array"); - ORT_ENFORCE(start_tensor.Shape() == ends_tensor.Shape(), "Starts and ends shape mismatch"); - ORT_ENFORCE(nullptr == axes_tensor || start_tensor.Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch"); - ORT_ENFORCE(nullptr == steps_tensor || start_tensor.Shape() == steps_tensor->Shape(), "Starts and steps shape mismatch"); +Status SliceBase::FillVectorsFromInput(const Tensor& start_tensor, + const Tensor& ends_tensor, + const Tensor* axes_tensor, + const Tensor* steps_tensor, + std::vector& input_starts, + std::vector& input_ends, + std::vector& input_axes, + std::vector& input_steps) { + ORT_RETURN_IF_NOT(start_tensor.Shape().NumDimensions() == 1, "Starts must be a 1-D array"); + ORT_RETURN_IF_NOT(ends_tensor.Shape().NumDimensions() == 1, "Ends must be a 1-D array"); + ORT_RETURN_IF_NOT(start_tensor.Shape() == ends_tensor.Shape(), "Starts and ends shape mismatch"); + ORT_RETURN_IF_NOT(nullptr == axes_tensor || start_tensor.Shape() == axes_tensor->Shape(), + "Starts and axes shape mismatch"); + ORT_RETURN_IF_NOT(nullptr == steps_tensor || start_tensor.Shape() == steps_tensor->Shape(), + "Starts and steps shape mismatch"); const auto& size = start_tensor.Shape().Size(); input_starts.resize(size); @@ -241,7 +264,11 @@ void SliceBase::FillVectorsFromInput(const Tensor& start_tensor, if (nullptr != steps_tensor) input_steps.resize(size); - 
if (start_tensor.IsDataType()) { + // check for type reduction of supported indices types + constexpr bool int32_enabled = utils::HasType(); + constexpr bool int64_enabled = utils::HasType(); + + if (int32_enabled && start_tensor.IsDataType()) { std::copy(start_tensor.Data(), start_tensor.Data() + size, input_starts.begin()); std::copy(ends_tensor.Data(), ends_tensor.Data() + size, input_ends.begin()); if (nullptr != axes_tensor) @@ -251,7 +278,7 @@ void SliceBase::FillVectorsFromInput(const Tensor& start_tensor, std::copy(steps_tensor->Data(), steps_tensor->Data() + size, input_steps.begin()); } - else if (start_tensor.IsDataType()) { + else if (int64_enabled && start_tensor.IsDataType()) { std::copy(start_tensor.Data(), start_tensor.Data() + size, input_starts.begin()); std::copy(ends_tensor.Data(), ends_tensor.Data() + size, input_ends.begin()); if (nullptr != axes_tensor) @@ -261,10 +288,13 @@ void SliceBase::FillVectorsFromInput(const Tensor& start_tensor, std::copy(steps_tensor->Data(), steps_tensor->Data() + size, input_steps.begin()); } - // should not reach this as no kernel is registered for this condition to be triggered - just an additional safety check else { - ORT_THROW("Data type for starts and ends inputs' need to be int32_t or int64_t, but instead got ", start_tensor.DataType()); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Data type for starts and ends inputs' is not supported in this build. 
Got ", + start_tensor.DataType()); } + + return Status::OK(); } template @@ -304,22 +334,40 @@ static Status SliceImpl(OpKernelContext* ctx, flattened_input_dims.back() = compute_metadata.p_flattened_output_dims_->back(); TensorShape input_shape(std::move(flattened_input_dims)); - auto input_iterator = SliceIterator(input_tensor, input_shape, compute_metadata.starts_, *compute_metadata.p_flattened_output_dims_, compute_metadata.steps_); + auto input_iterator = SliceIterator(input_tensor, input_shape, compute_metadata.starts_, + *compute_metadata.p_flattened_output_dims_, compute_metadata.steps_); create_output(input_iterator); } else { - auto input_iterator = SliceIterator(input_tensor, compute_metadata.starts_, compute_metadata.output_dims_, compute_metadata.steps_); + auto input_iterator = SliceIterator(input_tensor, compute_metadata.starts_, compute_metadata.output_dims_, + compute_metadata.steps_); create_output(input_iterator); } return Status::OK(); } +template +static inline bool CallSliceImplIfEnabled(OpKernelContext* ctx, + const Tensor& input_tensor, + SliceOp::PrepareForComputeMetadata& compute_metadata, + Status& status) { + constexpr bool enabled = utils::HasTypeWithSameSize(); + if (enabled) { + status = SliceImpl(ctx, input_tensor, compute_metadata); + } + + return enabled; +} + Status SliceBase::Compute(OpKernelContext* ctx) const { const auto* input_tensor_ptr = ctx->Input(0); - ORT_ENFORCE(input_tensor_ptr != nullptr, "Missing input tensor to be processed"); const auto& input_tensor = *input_tensor_ptr; const auto& input_dimensions = input_tensor.Shape().GetDims(); - if (input_dimensions.empty()) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars"); + + if (input_dimensions.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars"); + } + SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions); // Slice V10 & DynamicSlice @@ -328,8 +376,9 @@ Status 
SliceBase::Compute(OpKernelContext* ctx) const { std::vector input_ends; std::vector input_axes; std::vector input_steps; - FillVectorsFromInput(*ctx->Input(1), *ctx->Input(2), ctx->Input(3), - ctx->Input(4), input_starts, input_ends, input_axes, input_steps); + ORT_RETURN_IF_ERROR(FillVectorsFromInput(*ctx->Input(1), *ctx->Input(2), + ctx->Input(3), ctx->Input(4), + input_starts, input_ends, input_axes, input_steps)); ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata)); } @@ -340,28 +389,38 @@ Status SliceBase::Compute(OpKernelContext* ctx) const { Status status = Status::OK(); + bool supported = false; if (input_tensor.IsDataTypeString()) { - status = SliceImpl(ctx, input_tensor, compute_metadata); + if (utils::HasType()) { + supported = true; + status = SliceImpl(ctx, input_tensor, compute_metadata); + } } else { const auto element_size = input_tensor.DataType()->Size(); - + // call SliceImpl switch (element_size) { case sizeof(uint32_t): - status = SliceImpl(ctx, input_tensor, compute_metadata); + supported = CallSliceImplIfEnabled(ctx, input_tensor, compute_metadata, status); break; case sizeof(uint64_t): - status = SliceImpl(ctx, input_tensor, compute_metadata); + supported = CallSliceImplIfEnabled(ctx, input_tensor, compute_metadata, status); break; case sizeof(uint16_t): - status = SliceImpl(ctx, input_tensor, compute_metadata); + supported = CallSliceImplIfEnabled(ctx, input_tensor, compute_metadata, status); break; case sizeof(uint8_t): - status = SliceImpl(ctx, input_tensor, compute_metadata); + supported = CallSliceImplIfEnabled(ctx, input_tensor, compute_metadata, status); break; default: - ORT_THROW("Unsupported input data type of ", input_tensor.DataType()); + // leave 'supported' as false + break; } } + + if (!supported) { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input data type of ", input_tensor.DataType()); + } + return status; } diff --git 
a/onnxruntime/core/providers/cpu/tensor/slice.h b/onnxruntime/core/providers/cpu/tensor/slice.h index 7febbf84d2..be69df1b38 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.h +++ b/onnxruntime/core/providers/cpu/tensor/slice.h @@ -44,14 +44,14 @@ class SliceBase { SliceOp::PrepareForComputeMetadata& compute_metadata); // Slice V10 & DynamicSlice - static void FillVectorsFromInput(const Tensor& start_tensor, - const Tensor& ends_tensor, - const Tensor* axes_tensor, - const Tensor* steps_tensor, - std::vector& input_starts, - std::vector& input_ends, - std::vector& input_axes, - std::vector& input_steps); + static Status FillVectorsFromInput(const Tensor& start_tensor, + const Tensor& ends_tensor, + const Tensor* axes_tensor, + const Tensor* steps_tensor, + std::vector& input_starts, + std::vector& input_ends, + std::vector& input_axes, + std::vector& input_steps); protected: SliceBase(const OpKernelInfo& info, bool dynamic = false) diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.cc b/onnxruntime/core/providers/cpu/tensor/transpose.cc index 24770b235b..482dd019c7 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/cpu/tensor/transpose.cc @@ -2,12 +2,30 @@ // Licensed under the MIT License. 
#include "core/providers/cpu/tensor/transpose.h" + #include "core/framework/utils.h" #include "core/mlas/inc/mlas.h" +#include "core/providers/op_kernel_type_control.h" +#include "core/providers/op_kernel_type_control_utils.h" #include "utils.h" namespace onnxruntime { +namespace op_kernel_type_control { +// we're using one set of types for all opsets +ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS( + kCpuExecutionProvider, kOnnxDomain, Transpose, Input, 0, + ORT_OP_KERNEL_TYPE_CTRL_ALL_TENSOR_DATA_TYPES); +} // namespace op_kernel_type_control + +namespace { +// reduce the supported types with any global or op specific lists +using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, + Transpose, Input, 0); + +const std::vector dataTypeConstraints = BuildKernelDefConstraintsFunctorFromTypeList{}(); +} // namespace + /* A permutation [a,b,c,...] indicates that - The 0-th dimension of the output corresponds to the a-th dimension of input - The 1-st dimension of the output corresponds to the b-th dimension of input @@ -152,42 +170,54 @@ inline void CopyPrim(uint8_t* target, const uint8_t* source) { // The function does not check num_axes > 0 but this is expected. 
template -static void TypedDoTransposeEltWise(int64_t num_axes, const std::vector& target_dims, size_t num_blocks, +static bool TypedDoTransposeEltWise(int64_t num_axes, const std::vector& target_dims, size_t num_blocks, const std::vector& stride, const uint8_t* source, uint8_t* target) { - MultiIndex mindex; - IncrementIndexAndComputeOffsetSetup(mindex, num_axes, target_dims, stride, sizeof(T)); + constexpr bool enabled = utils::HasTypeWithSameSize(); - const uint8_t* local_source = source; - uint8_t* target_end = target + sizeof(T) * num_blocks; - for (; target != target_end; target += sizeof(T)) { - ORT_ENFORCE((local_source >= source) && (local_source < source + sizeof(T) * num_blocks)); - CopyPrim(target, local_source); - IncrementIndexAndComputeOffset(mindex, local_source); + if (enabled) { + MultiIndex mindex; + IncrementIndexAndComputeOffsetSetup(mindex, num_axes, target_dims, stride, sizeof(T)); + + const uint8_t* local_source = source; + uint8_t* target_end = target + sizeof(T) * num_blocks; + for (; target != target_end; target += sizeof(T)) { + ORT_ENFORCE((local_source >= source) && (local_source < source + sizeof(T) * num_blocks)); + CopyPrim(target, local_source); + IncrementIndexAndComputeOffset(mindex, local_source); + } } + + return enabled; } // DoTransposeEltWise: specialization of DoTranspose for the num_elts_in_block=1 case. // copies source tensor to target, transposing elements. // The stride vector indicates the transposition. 
-void DoTransposeEltWise(int64_t num_axes, const std::vector& target_dims, size_t num_blocks, - const std::vector& stride, const uint8_t* source, uint8_t* target, - size_t element_size) { +Status DoTransposeEltWise(int64_t num_axes, const std::vector& target_dims, size_t num_blocks, + const std::vector& stride, const uint8_t* source, uint8_t* target, + size_t element_size) { + bool enabled = false; switch (element_size) { case sizeof(uint64_t): - TypedDoTransposeEltWise(num_axes, target_dims, num_blocks, stride, source, target); + enabled = TypedDoTransposeEltWise(num_axes, target_dims, num_blocks, stride, source, target); break; case sizeof(uint32_t): - TypedDoTransposeEltWise(num_axes, target_dims, num_blocks, stride, source, target); + enabled = TypedDoTransposeEltWise(num_axes, target_dims, num_blocks, stride, source, target); break; case sizeof(uint16_t): - TypedDoTransposeEltWise(num_axes, target_dims, num_blocks, stride, source, target); + enabled = TypedDoTransposeEltWise(num_axes, target_dims, num_blocks, stride, source, target); break; case sizeof(uint8_t): - TypedDoTransposeEltWise(num_axes, target_dims, num_blocks, stride, source, target); + enabled = TypedDoTransposeEltWise(num_axes, target_dims, num_blocks, stride, source, target); break; default: - assert(false); + // leave enabled as false + break; } + + return enabled ? Status::OK() + : ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Transpose of element size not supported in this build. 
Size=", + element_size); } static void DoTransposeEltWise(int64_t num_axes, const std::vector& target_dims, size_t num_blocks, @@ -243,17 +273,25 @@ static Status DoUntypedTranspose(const std::vector& permutations, const } } + Status status = Status::OK(); + if (is_string_type) { - const auto* input_data = input.template Data(); - auto* output_data = output.template MutableData(); - if (1 == prefix_blocksize) { - DoTransposeSingleBlock(suffix_blocksize, input_data, output_data); - } else if (1 == suffix_blocksize) { - DoTransposeEltWise(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride, - input_data, output_data); + constexpr bool string_enabled = utils::HasType(); + + if (string_enabled) { + const auto* input_data = input.template Data(); + auto* output_data = output.template MutableData(); + if (1 == prefix_blocksize) { + DoTransposeSingleBlock(suffix_blocksize, input_data, output_data); + } else if (1 == suffix_blocksize) { + DoTransposeEltWise(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride, + input_data, output_data); + } else { + DoTransposeImpl(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, suffix_blocksize, stride, + input_data, output_data); + } } else { - DoTransposeImpl(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, suffix_blocksize, stride, - input_data, output_data); + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Transpose of std::string is not supported in this build."); } } else { const auto* input_data = reinterpret_cast(input.DataRaw()); @@ -261,15 +299,16 @@ static Status DoUntypedTranspose(const std::vector& permutations, const if (1 == prefix_blocksize) { DoTransposeSingleBlock(suffix_blocksize, input_data, output_data, element_size); } else if (1 == suffix_blocksize) { - DoTransposeEltWise(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride, - input_data, output_data, element_size); + // this may return a failed status if the data size is not 
supported in this build + status = DoTransposeEltWise(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride, + input_data, output_data, element_size); } else { DoTransposeImpl(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, suffix_blocksize, stride, input_data, output_data, element_size); } } - return Status::OK(); + return status; } /* @@ -686,13 +725,13 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Transpose, 1, 12, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()), + KernelDefBuilder().TypeConstraint("T", dataTypeConstraints), Transpose); ONNX_CPU_OPERATOR_KERNEL( Transpose, 13, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()), + KernelDefBuilder().TypeConstraint("T", dataTypeConstraints), Transpose); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.h b/onnxruntime/core/providers/cpu/tensor/transpose.h index 341975d475..c003b2e8f2 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.h +++ b/onnxruntime/core/providers/cpu/tensor/transpose.h @@ -16,9 +16,10 @@ namespace onnxruntime { */ bool IsTransposeReshape(const std::vector& perm, const std::vector& input_dims); -void DoTransposeEltWise(int64_t num_axes, const std::vector& target_dims, size_t num_blocks, - const std::vector& stride, const uint8_t* source, uint8_t* target, - size_t element_size); +// Public function for element-wise transpose, primarily to unit test any out of bounds access +Status DoTransposeEltWise(int64_t num_axes, const std::vector& target_dims, size_t num_blocks, + const std::vector& stride, const uint8_t* source, uint8_t* target, + size_t element_size); class TransposeBase { public: diff --git a/onnxruntime/core/providers/op_kernel_type_control.h b/onnxruntime/core/providers/op_kernel_type_control.h index c61c0381d4..d70ca7600b 100644 --- a/onnxruntime/core/providers/op_kernel_type_control.h +++ b/onnxruntime/core/providers/op_kernel_type_control.h @@ -62,14 +62,15 @@ 
struct GlobalAllowed {}; } // namespace tags // optionally holds a list of types associated with a tag class -// if types are defined, the data member 'types' should contain them in a type list -// otherwise, if no types are defined (distinct from an empty list of types), there should be no data member 'types' +// if types are defined, the type alias member called 'types' should contain them in a type list +// (e.g. using something like std::tuple or a boost::mp11::mp_list) +// otherwise, if no types are defined (distinct from an empty list of types), there should be no 'types' type alias // see the tags in onnxruntime::op_kernel_type_control::tags for intended uses template struct TypesHolder {}; /** - * Provides a type list of enabled types via the 'types' data member. + * Provides a type list of enabled types via the 'types' type alias member. * Enabled types are the set intersection of supported and allowed types. * * @tparam SupportedTypesHolder A 'TypesHolder' with a list of supported types. 
@@ -84,15 +85,15 @@ struct EnabledTypes { template using GetTypesMember = typename T::types; - // checks whether T has data member 'types' + // checks whether T has a type alias member called 'types' template using HasTypesMember = boost::mp11::mp_valid; static_assert(HasTypesMember::value, - "SupportedTypesHolder must have a 'types' data member."); + "SupportedTypesHolder must have a type alias called 'types'."); // the allowed type lists to consider - // for each element of AllowedTypesHolders, get and include a 'types' data member if present + // for each element of AllowedTypesHolders, get and include the 'types' type alias member if present using AllowedTypesMembers = boost::mp11::mp_transform< GetTypesMember, @@ -105,17 +106,10 @@ struct EnabledTypes { boost::mp11::mp_push_front>; static_assert(boost::mp11::mp_all_of::value, - "All 'types' data members must be type lists."); - - // converts type list L into a type set (type list with unique elements) - template - using MakeSet = - boost::mp11::mp_apply< - boost::mp11::mp_set_push_back, - boost::mp11::mp_append>, L>>; + "All 'types' type aliases must be type lists."); // type lists converted to type sets - using TypeSetsToConsider = boost::mp11::mp_transform; + using TypeSetsToConsider = boost::mp11::mp_transform; public: using types = boost::mp11::mp_apply; @@ -237,37 +231,6 @@ struct EnabledTypes { ::onnxruntime::op_kernel_type_control::kAllOpSets, \ ArgDirection, ArgIndex) -/** - * std::tuple type with the enabled types for a given Op kernel argument. - * - * @param OpProvider The Op provider. - * @param OpDomain The Op domain. - * @param OpName The Op name. - * @param OpSet The opset to use for the supported types list. - * @param ArgDirection Direction of the given Op kernel argument - Input or Output. - * @param ArgIndex Index of the given Op kernel argument. 
- */ -#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE( \ - OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex) \ - ::boost::mp11::mp_rename< \ - ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( \ - OpProvider, OpDomain, OpName, OpSet, ArgDirection, ArgIndex, SupportedTypeList), \ - std::tuple> - -/** - * std::tuple type with the enabled types for a given Op kernel argument that is valid for all opsets. - * - * @param OpProvider The Op provider. - * @param OpDomain The Op domain. - * @param OpName The Op name. - * @param ArgDirection Direction of the given Op kernel argument - Input or Output. - * @param ArgIndex Index of the given Op kernel argument. - */ -#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE_ALL_OPSETS( \ - OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \ - ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE(OpProvider, OpDomain, OpName, \ - ::onnxruntime::op_kernel_type_control::kAllOpSets, \ - ArgDirection, ArgIndex) /** * Usage example: * @@ -277,7 +240,7 @@ struct EnabledTypes { * namespace op_kernel_type_control { * // specify supported types, i.e., the full set of types that can be enabled * ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES( - * MyProvider, DomainContainingMyOp, MyOp, Input, 0, + * MyProvider, DomainContainingMyOp, MyOp, OpSet, Input, 0, * int, float, double); * } // namespace op_kernel_type_control * } // namespace onnxruntime diff --git a/onnxruntime/core/providers/op_kernel_type_control_utils.h b/onnxruntime/core/providers/op_kernel_type_control_utils.h new file mode 100644 index 0000000000..3e1354d43a --- /dev/null +++ b/onnxruntime/core/providers/op_kernel_type_control_utils.h @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "boost/mp11.hpp" + +#include "core/framework/data_types.h" + +namespace onnxruntime { +namespace utils { +/** +* Check if the set of types contains the specified type. 
+*/ +template +constexpr bool HasType() { + static_assert(boost::mp11::mp_is_set::value, "TypeSet must be a type set."); + + return boost::mp11::mp_set_contains::value; +} + +template +using SizeOfT = boost::mp11::mp_size_t; + +/** +* Check if the set of types contains a type with the same size as T. +* +* @remarks e.g. will return true if T is int32_t and the list contains any 4 byte type (i.e. sizeof(int32_t)) +* such as int32_t, uint32_t or float. +*/ +template +constexpr bool HasTypeWithSameSize() { + static_assert(boost::mp11::mp_is_set::value, "TypeSet must be a type set."); + + using EnabledTypeSizes = boost::mp11::mp_unique>; + return boost::mp11::mp_set_contains>::value; +} + +} // namespace utils +} // namespace onnxruntime + +/** Data types that are used in DataTypeImpl::AllTensorTypes() +*/ +#define ORT_OP_KERNEL_TYPE_CTRL_ALL_TENSOR_DATA_TYPES \ + bool, \ + float, double, \ + uint8_t, uint16_t, uint32_t, uint64_t, \ + int8_t, int16_t, int32_t, int64_t, \ + MLFloat16, BFloat16, \ + std::string diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index ae2fed48f9..f929c781e2 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -5,6 +5,7 @@ #include "test/providers/provider_test_utils.h" #include "test/providers/compare_provider_test_utils.h" #include "core/providers/cpu/tensor/transpose.h" +#include "test/util/include/asserts.h" namespace onnxruntime { namespace test { @@ -560,9 +561,9 @@ TEST(TransposeOpTest, DoTransposeEltWise) { 13.0f, 15.0f, 14.0f, 16.0f, 17.0f, 17.0f}; - DoTransposeEltWise(input_shape.size(), input_shape, 16, - stride, (uint8_t*)input_vals_end.data(), (uint8_t*)target.data(), - sizeof(float)); + ASSERT_STATUS_OK(DoTransposeEltWise(input_shape.size(), input_shape, 16, + stride, (uint8_t*)input_vals_end.data(), (uint8_t*)target.data(), + sizeof(float))); for (size_t i = 0; i < 
input_vals_end.size(); ++i) { ASSERT_TRUE(target[i] == expected_vals3[i]); } diff --git a/onnxruntime/test/testdata/mnist.readme.txt b/onnxruntime/test/testdata/mnist.readme.txt index 8d3bd63e37..cc370ab542 100644 --- a/onnxruntime/test/testdata/mnist.readme.txt +++ b/onnxruntime/test/testdata/mnist.readme.txt @@ -1,4 +1,4 @@ -The mnist model is used in a multiple tests for minimal/mobile builds in both ONNX and ORT formats. +The mnist model is used in multiple tests for minimal/mobile builds in both ONNX and ORT formats. We also save both ONNX and ORT format versions of the model with level 1 (aka 'basic') optimizations applied. - mnist.level1_opt.onnx makes sure the required operators for this model are automatically included in diff --git a/tools/python/util/ort_format_model/operator_type_usage_processors.py b/tools/python/util/ort_format_model/operator_type_usage_processors.py index 3c179e3355..8367900532 100644 --- a/tools/python/util/ort_format_model/operator_type_usage_processors.py +++ b/tools/python/util/ort_format_model/operator_type_usage_processors.py @@ -109,13 +109,18 @@ class DefaultTypeUsageProcessor(TypeUsageProcessor): def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict): for i in self._input_types.keys(): if i >= node.InputsLength(): - raise RuntimeError('Node has {} inputs. Tracker for {} incorrectly configured as it requires {}.' - .format(node.InputsLength(), self.name, i)) - - type_str = value_name_to_typestr(node.Inputs(i), value_name_to_typeinfo) - self._input_types[i].add(type_str) + # Some operators have fewer inputs in earlier versions where data that was an attribute + # became an input in later versions to allow it to be dynamically provided. Allow for that. + # e.g. Slice-1 had attributes for the indices, and Slice-10 moved those to be inputs + # raise RuntimeError('Node has {} inputs. Tracker for {} incorrectly configured as it requires {}.' 
+ # .format(node.InputsLength(), self.name, i)) + pass + else: + type_str = value_name_to_typestr(node.Inputs(i), value_name_to_typeinfo) + self._input_types[i].add(type_str) for o in self._output_types.keys(): + # Don't know of any ops where the number of outputs changed across versions, so require a valid length if o >= node.OutputsLength(): raise RuntimeError('Node has {} outputs. Tracker for {} incorrectly configured as it requires {}.' .format(node.OutputsLength(), self.name, o)) @@ -127,7 +132,7 @@ class DefaultTypeUsageProcessor(TypeUsageProcessor): if 0 not in self._input_types.keys(): # currently all standard typed registrations are for input 0. # custom registrations can be handled by operator specific processors (e.g. OneHotProcessor below). - raise RuntimeError('Expected typed registration to use type from input 0.') + raise RuntimeError('Expected typed registration to use type from input 0. Node:{}'.format(self.name)) return type_in_registration in self._input_types[0] @@ -254,8 +259,8 @@ def _create_operator_type_usage_processors(): # - some known large kernels # # Ops we are ignoring currently so as not to produce meaningless/unused output: - # - Implementation is not type specific: - # If, Loop, Reshape, Scan, Shape, Squeeze, Unsqueeze + # - Implementation is type agnostic: + # DynamicQuantizeMatMul, If, Loop, Reshape, Scan, Shape, Squeeze, Unsqueeze # - Only one type supported in the ORT implementation: # FusedConv, FusedGemm, FusedMatMul, TransposeMatMul # - Implementation does not have any significant type specific code: @@ -264,7 +269,7 @@ def _create_operator_type_usage_processors(): 'DequantizeLinear', 'Div', 'Equal', 'Exp', 'Expand', 'Gemm', 'Greater', 'Less', 'MatMul', 'Max', 'Min', 'Mul', 'NonMaxSuppression', 'NonZero', 'Pad', 'Range', 'Relu', 'Resize', - 'Sigmoid', 'Slice', 'Softmax', 'Split', 'Sub', 'Tile', 'TopK', 'Transpose'] + 'Sigmoid', 'Softmax', 'Split', 'Sub', 'Tile', 'TopK', 'Transpose'] internal_ops = ['QLinearAdd', 
'QLinearMul'] @@ -286,12 +291,15 @@ def _create_operator_type_usage_processors(): # # Operators that require custom handling # - add(DefaultTypeUsageProcessor('ai.onnx', 'Cast', inputs=[0], outputs=[0])) # track input0 and output0 + + # Cast switches on types of input 0 and output 0 + add(DefaultTypeUsageProcessor('ai.onnx', 'Cast', inputs=[0], outputs=[0])) # Operators that switch on the type of input 0 and 1 add(DefaultTypeUsageProcessor('ai.onnx', 'Gather', inputs=[0, 1])) add(DefaultTypeUsageProcessor('ai.onnx', 'GatherElements', inputs=[0, 1])) add(DefaultTypeUsageProcessor('ai.onnx', 'Pow', inputs=[0, 1])) + add(DefaultTypeUsageProcessor('ai.onnx', 'Slice', inputs=[0, 1])) # Operators that switch on output type add(DefaultTypeUsageProcessor('ai.onnx', 'ConstantOfShape', inputs=[], outputs=[0]))