Op kernel type reduction infrastructure. (#6466)

Add infrastructure to support type reduction in Op kernel implementations.
Update Cast and IsInf CPU kernels to use it.
This commit is contained in:
Edward Chen 2021-01-28 07:27:19 -08:00 committed by GitHub
parent 91b19b8364
commit d850fa63bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 760 additions and 390 deletions

3
.gitmodules vendored
View file

@ -59,6 +59,9 @@
[submodule "cmake/external/optional-lite"]
path = cmake/external/optional-lite
url = https://github.com/martinmoene/optional-lite.git
[submodule "cmake/external/mp11"]
path = cmake/external/mp11
url = https://github.com/boostorg/mp11.git
[submodule "cmake/external/coremltools"]
path = cmake/external/coremltools
url = https://github.com/apple/coremltools.git

View file

@ -218,6 +218,16 @@
"comments": "git submodule at cmake/external/mimalloc"
}
},
{
"component": {
"type": "git",
"git": {
"commitHash": "21cace4e574180ba64d9307a5e4ea9e5e94d3e8d",
"repositoryUrl": "https://github.com/boostorg/mp11.git"
},
"comments": "git submodule at cmake/external/mp11"
}
},
{
"component": {
"type": "git",

View file

@ -662,10 +662,6 @@ set(ONNXRUNTIME_INCLUDE_DIR ${REPO_ROOT}/include/onnxruntime)
add_subdirectory(external/date EXCLUDE_FROM_ALL)
if(onnxruntime_PREFER_SYSTEM_LIB)
find_package(re2)
endif()
set(SAFEINT_INCLUDE_DIR ${REPO_ROOT}/cmake/external/SafeInt)
add_library(safeint_interface INTERFACE)
target_include_directories(safeint_interface INTERFACE ${SAFEINT_INCLUDE_DIR})
@ -675,6 +671,11 @@ if(onnxruntime_DISABLE_EXCEPTIONS)
add_compile_definitions(optional_CONFIG_NO_EXCEPTIONS=1)
endif()
add_subdirectory(external/mp11 EXCLUDE_FROM_ALL)
if(onnxruntime_PREFER_SYSTEM_LIB)
find_package(re2)
endif()
if(NOT TARGET re2::re2)
add_subdirectory(external/re2 EXCLUDE_FROM_ALL)
set_target_properties(re2 PROPERTIES FOLDER "External/re2")

1
cmake/external/mp11 vendored Submodule

@ -0,0 +1 @@
Subproject commit 21cace4e574180ba64d9307a5e4ea9e5e94d3e8d

View file

@ -105,6 +105,8 @@ target_include_directories(onnxruntime_common
$<TARGET_PROPERTY:safeint_interface,INTERFACE_INCLUDE_DIRECTORIES>
${OPTIONAL_LITE_INCLUDE_DIR})
target_link_libraries(onnxruntime_common Boost::mp11)
if(NOT WIN32)
target_include_directories(onnxruntime_common PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/external/nsync/public")
endif()

View file

@ -53,6 +53,7 @@ file(GLOB onnxruntime_cpu_featurizers_cc_srcs CONFIGURE_DEPENDS
file(GLOB onnxruntime_providers_common_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/*.cc"
"${ONNXRUNTIME_ROOT}/core/providers/op_kernel_type_control_overrides.inc"
)
if(onnxruntime_USE_NUPHAR)

View file

@ -397,6 +397,7 @@ target_include_directories(winml_lib_image PRIVATE ${REPO_ROOT}/cmake/external/p
target_include_directories(winml_lib_image PRIVATE ${ONNXRUNTIME_INCLUDE_DIR}/core/platform/windows)
target_include_directories(winml_lib_image PRIVATE ${REPO_ROOT}/cmake/external/flatbuffers/include)
target_include_directories(winml_lib_image PRIVATE ${REPO_ROOT}/cmake/external/optional-lite/include)
target_include_directories(winml_lib_image PRIVATE ${REPO_ROOT}/cmake/external/mp11/include)
# Properties
set_target_properties(winml_lib_image
@ -512,6 +513,7 @@ target_include_directories(winml_lib_api PRIVATE ${REPO_ROOT}/cmake/external/gsl
target_include_directories(winml_lib_api PRIVATE ${REPO_ROOT}/cmake/external/SafeInt)
target_include_directories(winml_lib_api PRIVATE ${REPO_ROOT}/cmake/external/flatbuffers/include)
target_include_directories(winml_lib_api PRIVATE ${REPO_ROOT}/cmake/external/optional-lite/include)
target_include_directories(winml_lib_api PRIVATE ${REPO_ROOT}/cmake/external/mp11/include)
# Properties
set_target_properties(winml_lib_api
@ -594,6 +596,7 @@ target_include_directories(winml_lib_api_experimental PRIVATE ${REPO_ROOT}/cmake
target_include_directories(winml_lib_api_experimental PRIVATE ${REPO_ROOT}/cmake/external/SafeInt)
target_include_directories(winml_lib_api_experimental PRIVATE ${REPO_ROOT}/cmake/external/flatbuffers/include)
target_include_directories(winml_lib_api_experimental PRIVATE ${REPO_ROOT}/cmake/external/optional-lite/include)
target_include_directories(winml_lib_api_experimental PRIVATE ${REPO_ROOT}/cmake/external/mp11/include)
# Properties
set_target_properties(winml_lib_api_experimental
@ -748,6 +751,7 @@ target_include_directories(winml_dll PRIVATE ${REPO_ROOT}/cmake/external/eigen)
target_include_directories(winml_dll PRIVATE ${REPO_ROOT}/cmake/external/SafeInt)
target_include_directories(winml_dll PRIVATE ${REPO_ROOT}/cmake/external/flatbuffers/include)
target_include_directories(winml_dll PRIVATE ${REPO_ROOT}/cmake/external/optional-lite/include)
target_include_directories(winml_dll PRIVATE ${REPO_ROOT}/cmake/external/mp11/include)
# Properties
set_target_properties(winml_dll

View file

@ -163,7 +163,7 @@ function (get_winml_test_model_src
"${winml_test_src_path}/model/*.cpp")
set(${output_winml_test_model_src} ${winml_test_model_src} PARENT_SCOPE)
set(${winml_test_model_libs} onnx_test_data_proto onnx_test_runner_common onnxruntime_common onnxruntime_mlas
onnxruntime_graph onnxruntime_test_utils onnxruntime_framework onnxruntime_flatbuffers PARENT_SCOPE)
onnxruntime_graph onnxruntime_test_utils onnxruntime_framework onnxruntime_util onnxruntime_flatbuffers PARENT_SCOPE)
endfunction()
file(GLOB winml_test_common_src CONFIGURE_DEPENDS

View file

@ -52,11 +52,12 @@ class NonTensorTypeBase;
class PrimitiveDataTypeBase;
// MLFloat16
union MLFloat16 {
struct MLFloat16 {
uint16_t val;
explicit MLFloat16(uint16_t x) : val(x) {}
MLFloat16() : val(0) {}
explicit MLFloat16(uint16_t x) : val(x) {}
explicit MLFloat16(float f);
// Taken from https://stackoverflow.com/a/60047308/12627730
float AsFloat(uint32_t x) const {

View file

@ -3,12 +3,16 @@
#pragma once
#include <assert.h>
#include <stdint.h>
#include <array>
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>
#include "boost/mp11.hpp"
#include "core/common/common.h"
#include "core/common/type_list.h"
#include "core/framework/data_types.h"
#include "core/graph/onnx_protobuf.h"
@ -341,25 +345,54 @@ class MLTypeCallDispatcherRet {
}
};
// Version of the MLTypeDispatcher that has an input type which is passed through ('carried')
// as the first type parameter in the call to Fn when dispatching.
template <typename TCarried, template <typename, typename> class Fn, typename... Types>
class MLTypeCallDispatcherWithCarriedType {
// Version of MLTypeCallDispatcher that takes supported types as class-level template parameters.
// This enables easier use with type list representations of the supported types.
// The invocation-related template parameters like Fn move to the individual Invoke() methods.
// TODO consolidate this with the other MLTypeCallDispatcher classes
// can add additional methods to cover their usages, but need to update call sites
template <typename... Types>
class MLTypeCallDispatcher2 {
static_assert(boost::mp11::mp_is_set<TypeList<Types...>>::value,
"MLTypeCallDispatcher requires a set of unique types.");
int32_t dt_type_;
public:
explicit MLTypeCallDispatcherWithCarriedType(int32_t dt_type) noexcept : dt_type_(dt_type) {}
explicit MLTypeCallDispatcher2(int32_t dt_type) noexcept : dt_type_(dt_type) {}
template <typename... Args>
template <template <typename> class Fn, typename... Args>
void Invoke(Args&&... args) const {
mltype_dispatcher_internal::CallableDispatchableHelper helper(dt_type_);
int results[] = {0, helper.template Invoke<Types>(Fn<TCarried, Types>(), std::forward<Args>(args)...)...};
ORT_UNUSED_PARAMETER(results);
ORT_ENFORCE(helper.called_ < 2, "Check for duplicate types in MLTypeCallDispatcher");
static_cast<void>(std::array<int, sizeof...(Types)>{
helper.template Invoke<Types>(Fn<Types>(), std::forward<Args>(args)...)...});
// avoid "unused parameter" warning for the case where Types is empty
static_cast<void>(std::array<int, sizeof...(Args)>{(ORT_UNUSED_PARAMETER(args), 0)...});
ORT_ENFORCE(helper.called_ == 1, "Unsupported data type: ", dt_type_);
}
template <template <typename...> class Fn, typename LeadingTemplateArgTypeList, typename... Args>
void InvokeWithLeadingTemplateArgs(Args&&... args) const {
mltype_dispatcher_internal::CallableDispatchableHelper helper(dt_type_);
static_cast<void>(std::array<int, sizeof...(Types)>{
helper.template Invoke<Types>(
boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
std::forward<Args>(args)...)...});
// avoid "unused parameter" warning for the case where Types is empty
static_cast<void>(std::array<int, sizeof...(Args)>{(ORT_UNUSED_PARAMETER(args), 0)...});
ORT_ENFORCE(helper.called_ == 1, "Unsupported data type: ", dt_type_);
}
};
// the type MLTypeCallDispatcher2<T...> given a type list L<T...>
template <typename L>
using MLTypeCallDispatcherFromTypeList = boost::mp11::mp_apply<MLTypeCallDispatcher2, L>;
namespace data_types_internal {
enum class ContainerType : uint16_t {

View file

@ -5,6 +5,8 @@
#include <functional>
#include "boost/mp11.hpp"
#include "core/common/exceptions.h"
#include "core/common/logging/logging.h"
#include "core/common/status.h"
@ -481,4 +483,17 @@ inline std::vector<MLDataType> BuildKernelDefConstraints() {
return {DataTypeImpl::GetTensorType<T>(), DataTypeImpl::GetTensorType<Types>()...};
}
// functor that calls BuildKernelDefConstraints()
template <typename... Types>
struct BuildKernelDefConstraintsFunctor {
std::vector<MLDataType> operator()() const {
return BuildKernelDefConstraints<Types...>();
}
};
// the type BuildKernelDefConstraintsFunctor<T...> given a type list L<T...>
template <typename L>
using BuildKernelDefConstraintsFunctorFromTypeList =
boost::mp11::mp_apply<BuildKernelDefConstraintsFunctor, L>;
} // namespace onnxruntime

View file

@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
// this type represents a compile-time list of types
template <typename... T>
struct TypeList {};
}

View file

@ -7,6 +7,7 @@
#include "core/framework/sparse_tensor.h"
#include "core/framework/data_types_internal.h"
#include "core/graph/onnx_protobuf.h"
#include "core/util/math.h"
#ifdef __GNUC__
#pragma GCC diagnostic push
@ -21,6 +22,9 @@
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
MLFloat16::MLFloat16(float f) : val{math::floatToHalf(f)} {}
// Return the MLDataType used for a generic Tensor
template <>
MLDataType DataTypeImpl::GetType<Tensor>() {

View file

@ -1,193 +1,231 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cstddef>
#include <iomanip>
#include <sstream>
#include "boost/mp11.hpp"
#include "gsl/gsl"
#include "core/common/common.h"
#include "core/common/type_list.h"
#include "core/framework/data_types.h"
#include "core/framework/data_types_internal.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/util/math.h"
#include "core/providers/op_kernel_type_control.h"
#include "core/util/math_cpuonly.h"
#include "Eigen/src/Core/arch/Default/Half.h"
#if defined(_M_AMD64)
#include "core/mlas/inc/mlas.h"
#endif
// FUTURE:
// Float16 and String have expensive special cased handling. Enable by default, but provide an easy way to disable
// in the future if needed. Disabling both saves ~50KB.
#define CAST_FLOAT16_ENABLED
#define CAST_STRING_ENABLED
using namespace ONNX_NAMESPACE;
using namespace boost::mp11;
namespace onnxruntime {
namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0,
bool,
float, double,
uint8_t, uint16_t, uint32_t, uint64_t,
int8_t, int16_t, int32_t, int64_t,
MLFloat16, BFloat16,
std::string);
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0,
bool,
float, double,
uint8_t, uint16_t, uint32_t, uint64_t,
int8_t, int16_t, int32_t, int64_t,
MLFloat16, BFloat16,
std::string);
} // namespace op_kernel_type_control
namespace {
template <typename SrcType, typename DstType>
inline void CastData(const Tensor& in, Tensor& out, const TensorShape& shape) {
ptrdiff_t shape_size = gsl::narrow<ptrdiff_t>(shape.Size());
auto in_vector = ConstEigenVectorMap<SrcType>(in.Data<SrcType>(), shape_size);
auto output_vector = EigenVectorMap<DstType>(out.MutableData<DstType>(), shape_size);
output_vector = in_vector.template cast<DstType>();
}
#ifdef CAST_FLOAT16_ENABLED
template <>
inline void CastData<float, MLFloat16>(const Tensor& in, Tensor& out, const TensorShape& shape) {
auto out_data = out.MutableData<MLFloat16>();
ptrdiff_t shape_size = gsl::narrow<ptrdiff_t>(shape.Size());
auto in_vector = ConstEigenVectorMap<float>(in.Data<float>(), shape_size);
auto output_vector = EigenVectorMap<Eigen::half>(static_cast<Eigen::half*>(static_cast<void*>(out_data)),
shape_size);
output_vector = in_vector.template cast<Eigen::half>();
}
using EnabledSrcTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0);
using EnabledDstTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0);
template <>
inline void CastData<MLFloat16, float>(const Tensor& in, Tensor& out, const TensorShape& shape) {
auto out_data = out.MutableData<float>();
auto in_data = in.Data<MLFloat16>();
ptrdiff_t shape_size = gsl::narrow<ptrdiff_t>(shape.Size());
#if defined(_M_AMD64)
MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size);
#else
auto in_vector = ConstEigenVectorMap<Eigen::half>(static_cast<const Eigen::half*>(static_cast<const void*>(in_data)),
shape_size);
auto output_vector = EigenVectorMap<float>(out_data, shape_size);
output_vector = in_vector.template cast<float>();
#endif
}
using IndirectCastTypes = TypeList<MLFloat16, BFloat16>;
template <>
inline void CastData<float, BFloat16>(const Tensor& in, Tensor& out, const TensorShape& shape) {
auto out_data = out.template MutableData<BFloat16>();
ptrdiff_t shape_size = gsl::narrow<ptrdiff_t>(shape.Size());
auto in_vector = ConstEigenVectorMap<float>(in.template Data<float>(), shape_size);
auto output_vector = EigenVectorMap<BFloat16>(out_data, shape_size);
output_vector = in_vector.template cast<BFloat16>();
}
template <typename Type>
using IsDirectCastType = mp_not<mp_contains<IndirectCastTypes, Type>>;
template <>
inline void CastData<BFloat16, float>(const Tensor& in, Tensor& out, const TensorShape& shape) {
auto out_data = out.template MutableData<float>();
auto in_data = in.template Data<BFloat16>();
ptrdiff_t shape_size = gsl::narrow<ptrdiff_t>(shape.Size());
auto in_vector = ConstEigenVectorMap<BFloat16>(in_data, shape_size);
auto output_vector = EigenVectorMap<float>(out_data, shape_size);
output_vector = in_vector.unaryExpr([](BFloat16 val) { return val.ToFloat(); });
}
#endif
template <typename... Types>
using AreAllDirectCastTypes = mp_all<IsDirectCastType<Types>...>;
#ifdef CAST_STRING_ENABLED
// string cast helpers
// handle floating point input separately
template <typename SrcType>
typename std::enable_if<std::is_floating_point<SrcType>::value, void>::type
CastToStringData(const Tensor& in, Tensor& out, const TensorShape& shape) {
const int64_t len = shape.Size();
const auto input_data = in.DataAsSpan<SrcType>();
auto output_data = out.MutableDataAsSpan<std::string>();
for (int i = 0; i < len; ++i) {
if (std::isnan(input_data[i])) {
output_data[i] = "NaN";
} else if (std::isinf(input_data[i])) {
if (input_data[i] < std::numeric_limits<SrcType>::lowest()) {
output_data[i] = "-INF";
} else {
output_data[i] = "INF";
}
CastToString(const SrcType& input, std::string& output) {
if (std::isnan(input)) {
output = "NaN";
} else if (std::isinf(input)) {
if (input < std::numeric_limits<SrcType>::lowest()) {
output = "-INF";
} else {
// setprecision to 8 to match numpy default behavior
std::ostringstream convert;
convert << std::setprecision(8) << input_data[i];
output_data[i] = convert.str();
output = "INF";
}
} else {
// setprecision to 8 to match numpy default behavior
std::ostringstream convert;
convert << std::setprecision(8) << input;
output = convert.str();
}
}
template <typename SrcType>
typename std::enable_if<!std::is_floating_point<SrcType>::value, void>::type
CastToStringData(const Tensor& in, Tensor& out, const TensorShape& shape) {
const int64_t len = shape.Size();
const auto input_data = in.DataAsSpan<SrcType>();
auto output_data = out.MutableDataAsSpan<std::string>();
for (int i = 0; i < len; ++i) {
std::ostringstream convert;
convert << input_data[i];
output_data[i] = convert.str();
}
CastToString(const SrcType& input, std::string& output) {
std::ostringstream convert;
convert << input;
output = convert.str();
}
template <typename DstType>
void CastFromStringData(const Tensor& in, Tensor& out, const TensorShape& shape) {
const int64_t len = shape.Size();
if (std::is_same<DstType, float>::value) {
auto* mutable_data = out.MutableData<float>();
for (int i = 0; i < len; ++i) {
mutable_data[i] = std::stof(in.Data<std::string>()[i]);
}
} else if (std::is_same<DstType, double>::value) {
auto* mutable_data = out.MutableData<double>();
for (int i = 0; i < len; ++i) {
mutable_data[i] = std::stod(in.Data<std::string>()[i]);
}
} else if (std::is_same<DstType, int8_t>::value) {
auto* mutable_data = out.MutableData<int8_t>();
for (int i = 0; i < len; ++i) {
int temp_i = std::stoi(in.Data<std::string>()[i]);
mutable_data[i] = static_cast<int8_t>(temp_i);
}
} else if (std::is_same<DstType, uint8_t>::value) {
auto* mutable_data = out.MutableData<uint8_t>();
for (int i = 0; i < len; ++i) {
unsigned long temp_ui = std::stoul(in.Data<std::string>()[i]);
mutable_data[i] = static_cast<uint8_t>(temp_ui);
}
} else if (std::is_same<DstType, int16_t>::value) {
auto* mutable_data = out.MutableData<int16_t>();
for (int i = 0; i < len; ++i) {
int temp_i = std::stoi(in.Data<std::string>()[i]);
mutable_data[i] = static_cast<int16_t>(temp_i);
}
} else if (std::is_same<DstType, uint16_t>::value) {
auto* mutable_data = out.MutableData<uint16_t>();
for (int i = 0; i < len; ++i) {
unsigned long temp_ui = std::stoul(in.Data<std::string>()[i]);
mutable_data[i] = static_cast<uint16_t>(temp_ui);
}
} else if (std::is_same<DstType, int32_t>::value) {
auto* mutable_data = out.MutableData<int32_t>();
for (int i = 0; i < len; ++i) {
mutable_data[i] = std::stol(in.Data<std::string>()[i]);
}
} else if (std::is_same<DstType, uint32_t>::value) {
auto* mutable_data = out.MutableData<uint32_t>();
for (int i = 0; i < len; ++i) {
mutable_data[i] = std::stoul(in.Data<std::string>()[i]);
}
} else if (std::is_same<DstType, int64_t>::value) {
auto* mutable_data = out.MutableData<int64_t>();
for (int i = 0; i < len; ++i) {
mutable_data[i] = std::stoll(in.Data<std::string>()[i]);
}
} else if (std::is_same<DstType, uint64_t>::value) {
auto* mutable_data = out.MutableData<uint64_t>();
for (int i = 0; i < len; ++i) {
mutable_data[i] = std::stoull(in.Data<std::string>()[i]);
}
} else {
#ifdef ORT_NO_RTTI
ORT_THROW("Unsupported type in cast op");
#else
ORT_THROW("Unsupported type in cast op: from String to ", typeid(DstType).name());
#endif
typename std::enable_if<std::is_floating_point<DstType>::value, void>::type
CastFromString(const std::string& input, DstType& output) {
static_assert(sizeof(DstType) <= sizeof(double),
"largest supported floating point type is double");
output = gsl::narrow_cast<DstType>(std::stod(input));
}
template <typename DstType>
typename std::enable_if<std::is_integral<DstType>::value && std::is_unsigned<DstType>::value, void>::type
CastFromString(const std::string& input, DstType& output) {
static_assert(sizeof(DstType) <= sizeof(unsigned long long),
"largest supported unsigned integral type is unsigned long long");
output = gsl::narrow_cast<DstType>(std::stoull(input));
}
template <typename DstType>
typename std::enable_if<std::is_integral<DstType>::value && std::is_signed<DstType>::value, void>::type
CastFromString(const std::string& input, DstType& output) {
static_assert(sizeof(DstType) <= sizeof(long long),
"largest supported signed integral type is long long");
output = gsl::narrow_cast<DstType>(std::stoll(input));
}
// generic scalar X -> Y
template <typename SrcType, typename DstType>
struct ScalarDirectCaster {
void Cast(const SrcType& in, DstType& out) const {
out = static_cast<DstType>(in);
}
};
// scalar X -> string
template <typename SrcType>
struct ScalarDirectCaster<SrcType, std::string> {
void Cast(const SrcType& in, std::string& out) const {
CastToString<SrcType>(in, out);
}
};
// scalar string -> X
template <typename DstType>
struct ScalarDirectCaster<std::string, DstType> {
void Cast(const std::string& in, DstType& out) const {
CastFromString<DstType>(in, out);
}
};
// helper for indirect cast types
template <typename SrcType, typename DstType, typename IntermediateType>
struct ScalarIndirectCaster {
void Cast(const SrcType& in, DstType& out) const {
IntermediateType intermediate;
ScalarDirectCaster<SrcType, IntermediateType>{}.Cast(in, intermediate);
ScalarDirectCaster<IntermediateType, DstType>{}.Cast(intermediate, out);
}
};
template <typename SrcType, typename DstType, class Enable = void>
struct ScalarCaster;
template <typename SrcType, typename DstType>
struct ScalarCaster<
SrcType, DstType,
typename std::enable_if<AreAllDirectCastTypes<SrcType, DstType>::value>::type> {
void Cast(const SrcType& in, DstType& out) const {
ScalarDirectCaster<SrcType, DstType>{}.Cast(in, out);
}
};
template <typename SrcType, typename DstType>
struct ScalarCaster<
SrcType, DstType,
typename std::enable_if<!AreAllDirectCastTypes<SrcType, DstType>::value>::type> {
void Cast(const SrcType& in, DstType& out) const {
ScalarIndirectCaster<SrcType, DstType, float>{}.Cast(in, out);
}
};
// generic tensor X -> Y
template <typename SrcType, typename DstType>
struct TensorCaster {
void Cast(const Tensor& in, Tensor& out, const TensorShape& shape) const {
const std::ptrdiff_t shape_size = gsl::narrow<std::ptrdiff_t>(shape.Size());
const auto in_vector = ConstEigenVectorMap<SrcType>(in.Data<SrcType>(), shape_size);
auto out_vector = EigenVectorMap<DstType>(out.MutableData<DstType>(), shape_size);
out_vector = in_vector.unaryExpr([](const SrcType& in_scalar) {
DstType out_scalar;
ScalarCaster<SrcType, DstType>{}.Cast(in_scalar, out_scalar);
return out_scalar;
});
}
};
template <typename SrcType, typename DstType>
void CastStringTensor(const Tensor& in, Tensor& out, const TensorShape& shape) {
static_assert(std::is_same<SrcType, std::string>::value || std::is_same<DstType, std::string>::value,
"Either SrcType or DstType must be std::string.");
const std::ptrdiff_t shape_size = gsl::narrow<std::ptrdiff_t>(shape.Size());
const auto in_data = in.DataAsSpan<SrcType>();
const auto out_data = out.MutableDataAsSpan<DstType>();
for (std::ptrdiff_t i = 0; i < shape_size; ++i) {
ScalarCaster<SrcType, DstType>{}.Cast(in_data[i], out_data[i]);
}
}
#endif
} // namespace
// tensor X -> string
template <typename SrcType>
struct TensorCaster<SrcType, std::string> {
void Cast(const Tensor& in, Tensor& out, const TensorShape& shape) const {
CastStringTensor<SrcType, std::string>(in, out, shape);
}
};
// tensor string -> X
template <typename DstType>
struct TensorCaster<std::string, DstType> {
void Cast(const Tensor& in, Tensor& out, const TensorShape& shape) const {
CastStringTensor<std::string, DstType>(in, out, shape);
}
};
#if defined(_M_AMD64)
// tensor MLFloat16 -> float
template <>
struct TensorCaster<MLFloat16, float> {
void Cast(const Tensor& in, Tensor& out, const TensorShape& shape) const {
auto out_data = out.MutableData<float>();
auto in_data = in.Data<MLFloat16>();
const size_t shape_size = gsl::narrow<size_t>(shape.Size());
MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size);
}
};
#endif
class Cast final : public OpKernel {
public:
@ -201,46 +239,63 @@ class Cast final : public OpKernel {
Status Compute(OpKernelContext* context) const override;
private:
template <typename TSrc>
struct SrcDispatcher;
template <typename TSrc, typename TDest>
struct Dispatcher;
template <typename T>
struct StringDispatcher;
ONNX_NAMESPACE::TensorProto_DataType to_;
};
const std::vector<MLDataType> castOpTypeConstraints{
DataTypeImpl::GetTensorType<bool>(),
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<uint16_t>(),
DataTypeImpl::GetTensorType<uint32_t>(),
DataTypeImpl::GetTensorType<uint64_t>(),
DataTypeImpl::GetTensorType<int8_t>(),
DataTypeImpl::GetTensorType<int16_t>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
#ifdef CAST_FLOAT16_ENABLED
DataTypeImpl::GetTensorType<MLFloat16>(),
DataTypeImpl::GetTensorType<BFloat16>(),
#endif
#ifdef CAST_STRING_ENABLED
DataTypeImpl::GetTensorType<std::string>(),
#endif
template <typename TSrc, typename TDst>
struct Dispatcher {
void operator()(const Tensor& src, Tensor& dst, const TensorShape& shape) {
TensorCaster<TSrc, TDst>{}.Cast(src, dst, shape);
}
};
template <typename TSrc>
struct SrcDispatcher {
void operator()(int32_t to, const Tensor& src, Tensor& dst, const TensorShape& shape) {
using DstTypes = mp_remove_if_q<EnabledDstTypes, mp_bind_front<std::is_same, TSrc>>;
utils::MLTypeCallDispatcherFromTypeList<DstTypes> dispatcher{to};
dispatcher.template InvokeWithLeadingTemplateArgs<Dispatcher, TypeList<TSrc>>(src, dst, shape);
}
};
Status Cast::Compute(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
const TensorShape& shape = X->Shape();
Tensor* Y = context->Output(0, shape);
if (shape.Size() == 0) {
return Status::OK();
}
const auto from = X->GetElementType();
if (from == to_) {
// will copy if X and Y have different buffers
CopyCpuTensor(X, Y);
return Status::OK();
}
utils::MLTypeCallDispatcherFromTypeList<EnabledSrcTypes> dispatcher{from};
dispatcher.Invoke<SrcDispatcher>(to_, *X, *Y, shape);
return Status::OK();
}
const std::vector<MLDataType> castSrcTypeConstraints =
BuildKernelDefConstraintsFunctorFromTypeList<EnabledSrcTypes>{}();
const std::vector<MLDataType> castDstTypeConstraints =
BuildKernelDefConstraintsFunctorFromTypeList<EnabledDstTypes>{}();
} // namespace
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Cast,
6,
12,
KernelDefBuilder()
.TypeConstraint("T1", castOpTypeConstraints)
.TypeConstraint("T2", castOpTypeConstraints)
.TypeConstraint("T1", castSrcTypeConstraints)
.TypeConstraint("T2", castDstTypeConstraints)
.MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing
Cast);
@ -248,152 +303,9 @@ ONNX_CPU_OPERATOR_KERNEL(
Cast,
13,
KernelDefBuilder()
.TypeConstraint("T1", castOpTypeConstraints)
.TypeConstraint("T2", castOpTypeConstraints)
.TypeConstraint("T1", castSrcTypeConstraints)
.TypeConstraint("T2", castDstTypeConstraints)
.MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing
Cast);
// default dispatch
template <typename TSrc, typename TDst>
struct Cast::Dispatcher {
void operator()(const Tensor& src, Tensor& dst, const TensorShape& shape) {
CastData<TSrc, TDst>(src, dst, shape);
}
};
template <typename TSrc>
struct Cast::SrcDispatcher {
void operator()(int32_t to, const Tensor& src, Tensor& dst, const TensorShape& shape) {
utils::MLTypeCallDispatcherWithCarriedType<TSrc, Cast::Dispatcher,
float, double, int8_t, uint8_t, int16_t, uint16_t,
int32_t, uint32_t, int64_t, uint64_t, bool>
t_disp(to);
t_disp.Invoke(src, dst, shape);
}
};
#ifdef CAST_STRING_ENABLED
template <typename T>
struct Cast::StringDispatcher {
void operator()(bool to_string, const Tensor& src, Tensor& dst, const TensorShape& shape) {
if (to_string) {
CastToStringData<T>(src, dst, shape);
} else {
CastFromStringData<T>(src, dst, shape);
}
}
};
#endif
Status Cast::Compute(OpKernelContext* context) const {
Status status = Status::OK();
const Tensor* X = context->Input<Tensor>(0);
const TensorShape& shape = X->Shape();
Tensor* Y = context->Output(0, shape);
if (shape.Size() == 0) {
return status;
}
auto from = X->GetElementType();
if (from == to_) {
// will copy if X and Y have different buffers
CopyCpuTensor(X, Y);
return status;
}
#ifdef CAST_STRING_ENABLED
// special case strings
if (from == ONNX_NAMESPACE::TensorProto_DataType_STRING ||
to_ == ONNX_NAMESPACE::TensorProto_DataType_STRING) {
bool to_string = to_ == ONNX_NAMESPACE::TensorProto_DataType_STRING;
utils::MLTypeCallDispatcher<StringDispatcher,
float, double,
#ifdef CAST_FLOAT16_ENABLED
MLFloat16, /*BFloat16,*/
#endif
int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, int64_t, uint64_t,
bool>
t_disp(to_string ? from : to_);
t_disp.Invoke(to_string, *X, *Y, shape);
} else
#endif
{
auto do_cast = [](int32_t from, int32_t to, const Tensor& src, Tensor& dst, const TensorShape& shape) {
utils::MLTypeCallDispatcher<SrcDispatcher,
float, double, // MLFloat16 is special cased below
int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, int64_t, uint64_t, bool>
t_disp(from);
t_disp.Invoke(to, src, dst, shape);
};
#ifdef CAST_FLOAT16_ENABLED
// MLFloat16 needs special handling
if (from == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
if (to_ == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
CastData<MLFloat16, float>(*X, *Y, shape);
} else {
// need to cast to float first in a temporary buffer
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto tmp_buffer = IAllocator::MakeUniquePtr<float>(allocator, gsl::narrow<size_t>(shape.Size()));
Tensor tmp_tensor(DataTypeImpl::GetType<float>(), shape, tmp_buffer.get(), allocator->Info());
CastData<MLFloat16, float>(*X, tmp_tensor, shape);
do_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, to_, tmp_tensor, *Y, shape);
}
} else if (to_ == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
if (from == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
CastData<float, MLFloat16>(*X, *Y, shape);
} else {
// need to cast to float first in a temporary buffer
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto tmp_buffer = IAllocator::MakeUniquePtr<float>(allocator, gsl::narrow<size_t>(shape.Size()));
Tensor tmp_tensor(DataTypeImpl::GetType<float>(), shape, tmp_buffer.get(), allocator->Info());
do_cast(from, ONNX_NAMESPACE::TensorProto_DataType_FLOAT, *X, tmp_tensor, shape);
CastData<float, MLFloat16>(tmp_tensor, *Y, shape);
}
} else if (from == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
if (to_ == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
CastData<BFloat16, float>(*X, *Y, shape);
} else {
// need to cast to float first in a temporary buffer
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto tmp_buffer = IAllocator::MakeUniquePtr<float>(allocator, gsl::narrow<size_t>(shape.Size()));
Tensor tmp_tensor(DataTypeImpl::GetType<float>(), shape, tmp_buffer.get(), allocator->Info());
CastData<BFloat16, float>(*X, tmp_tensor, shape);
do_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, to_, tmp_tensor, *Y, shape);
}
} else if (to_ == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
if (from == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
CastData<float, BFloat16>(*X, *Y, shape);
} else {
// need to cast to float first in a temporary buffer
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto tmp_buffer = IAllocator::MakeUniquePtr<float>(allocator, gsl::narrow<size_t>(shape.Size()));
Tensor tmp_tensor(DataTypeImpl::GetType<float>(), shape, tmp_buffer.get(), allocator->Info());
do_cast(from, ONNX_NAMESPACE::TensorProto_DataType_FLOAT, *X, tmp_tensor, shape);
CastData<float, BFloat16>(tmp_tensor, *Y, shape);
}
}
else
#endif
{
do_cast(from, to_, *X, *Y, shape);
}
}
return status;
}
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -1,18 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/op_kernel.h"
#include "core/common/common.h"
#include "core/framework/tensor.h"
#include "core/util/math_cpuonly.h"
#include <cmath>
#include "core/common/common.h"
#include "core/framework/data_types_internal.h"
#include "core/framework/op_kernel.h"
#include "core/framework/tensor.h"
#include "core/providers/op_kernel_type_control.h"
#include "core/util/math_cpuonly.h"
namespace onnxruntime {
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#IsInf
namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0,
float, double);
} // namespace op_kernel_type_control
class IsInf final : public OpKernel {
public:
using EnabledTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0);
explicit IsInf(const OpKernelInfo& info);
Status Compute(OpKernelContext* context) const override;
@ -25,8 +35,8 @@ ONNX_CPU_OPERATOR_KERNEL(
IsInf,
10,
KernelDefBuilder()
.TypeConstraint("T1", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>()})
.TypeConstraint(
"T1", BuildKernelDefConstraintsFunctorFromTypeList<IsInf::EnabledTypes>{}())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
IsInf);
@ -39,32 +49,34 @@ IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) {
namespace isinf_internal {
template <class T>
void ComputeImpl(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) {
const auto total_items = X.Shape().Size();
auto output_data = Y.template MutableData<bool>();
struct ComputeDispatchTarget {
void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
const auto total_items = X.Shape().Size();
auto output_data = Y.template MutableData<bool>();
if (detect_positive && detect_negative) {
EigenMap<bool>(Y) = EigenMap<T>(X).array().isInf();
} else if (detect_positive) {
auto input_data = X.template Data<T>();
auto end_data = input_data + total_items;
std::transform(
input_data, end_data, output_data, [](T v) {
return (v == std::numeric_limits<T>::infinity());
});
if (detect_positive && detect_negative) {
EigenMap<bool>(Y) = EigenMap<T>(X).array().isInf();
} else if (detect_positive) {
auto input_data = X.template Data<T>();
auto end_data = input_data + total_items;
std::transform(
input_data, end_data, output_data, [](T v) {
return (v == std::numeric_limits<T>::infinity());
});
} else if (detect_negative) {
auto input_data = X.template Data<T>();
auto end_data = input_data + total_items;
std::transform(
input_data, end_data, output_data, [](T v) {
return (v == -std::numeric_limits<T>::infinity());
});
} else {
// all false
memset(output_data, false, total_items);
} else if (detect_negative) {
auto input_data = X.template Data<T>();
auto end_data = input_data + total_items;
std::transform(
input_data, end_data, output_data, [](T v) {
return (v == -std::numeric_limits<T>::infinity());
});
} else {
// all false
memset(output_data, false, total_items);
}
}
}
};
} // namespace isinf_internal
Status IsInf::Compute(OpKernelContext* context) const {
@ -75,14 +87,8 @@ Status IsInf::Compute(OpKernelContext* context) const {
using namespace isinf_internal;
if (X.IsDataType<float>()) {
ComputeImpl<float>(X, Y, detect_positive_ != 0, detect_negative_ != 0);
} else if (X.IsDataType<double>()) {
ComputeImpl<double>(X, Y, detect_positive_ != 0, detect_negative_ != 0);
} else {
// should not reach this as no kernel is registered for this condition to be triggered - just an additional safety check
ORT_THROW("Data type X must be float or double, but instead got ", X.DataType());
}
utils::MLTypeCallDispatcherFromTypeList<EnabledTypes> dispatcher{X.GetElementType()};
dispatcher.Invoke<ComputeDispatchTarget>(X, Y, detect_positive_ != 0, detect_negative_ != 0);
return Status::OK();
}

View file

@ -81,7 +81,7 @@ Status Pad<T>::ComputeInternal(OpKernelContext* ctx) const {
}
}
T raw_value(0);
T raw_value{};
const Tensor* value_tensor = ctx->Input<Tensor>(2);
if (nullptr != value_tensor) {
ORT_ENFORCE(utils::IsPrimitiveDataType<T>(value_tensor->DataType()) &&

View file

@ -0,0 +1,232 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <cstdint>
#include <tuple>
#include "boost/mp11.hpp"
#include "core/common/type_list.h"
#include "core/framework/data_types.h"
/**
* These utilities provide a way to control what types are enabled for an Op kernel implementation.
*
* At a high level, we have the notion of supported, allowed, and enabled types.
* - Supported types are the types that the Op kernel implementation supports by default.
* - Allowed types are the types for which support is requested (for example, by external configuration).
* - Enabled types are the types that are supported in the actual, compiled implementation. They are obtained from the
* intersection of supported and allowed types.
*
* The types are associated with an Op kernel argument. It is also possible to specify a global list of allowed types.
*
* Use of these utilities is optional. They are useful for cases where one registered Op kernel handles multiple types.
*
* See the macros below for usage details.
*/
namespace onnxruntime {
namespace op_kernel_type_control {
enum class OpArgDirection {
Input,
Output
};
using OpArgIndex = size_t;
namespace tags {
// a tag that identifies the target (Op argument) of the specified types
template <typename OpTag, OpArgDirection ArgDirection, OpArgIndex ArgIndex>
struct OpArg {};
// a tag that indicates the supported types for a particular Op argument, identified by OpArgTag,
// for a kernel in a particular provider, identified by ProviderTag
template <typename OpArgTag, typename ProviderTag>
struct Supported {};
// a tag that indicates the allowed types for a particular Op argument, identified by OpArgTag
template <typename OpArgTag>
struct Allowed {};
// a tag that indicates the globally allowed types
struct GlobalAllowed {};
} // namespace tags
// optionally holds a list of types associated with a tag class
// if types are defined, the data member 'types' should contain them in a type list
// otherwise, if no types are defined (distinct from an empty list of types), there should be no data member 'types'
// see the tags in onnxruntime::op_kernel_type_control::tags for intended uses
template <typename Tag>
struct TypesHolder {};
/**
* Provides a type list of enabled types via the 'types' data member.
* Enabled types are the set intersection of supported and allowed types.
*
* @tparam SupportedTypesHolder A 'TypesHolder' with a list of supported types.
* @tparam AllowedTypesHolders A list of 'TypesHolder's each with an optional list of allowed types.
*/
template <typename SupportedTypesHolder, typename AllowedTypesHolders>
struct EnabledTypes {
private:
static_assert(boost::mp11::mp_is_list<AllowedTypesHolders>::value,
"AllowedTypesHolders must be a type list.");
template <typename T>
using GetTypesMember = typename T::types;
// checks whether T has data member 'types'
template <typename T>
using HasTypesMember = boost::mp11::mp_valid<GetTypesMember, T>;
static_assert(HasTypesMember<SupportedTypesHolder>::value,
"SupportedTypesHolder must have a 'types' data member.");
// the allowed type lists to consider
// for each element of AllowedTypesHolders, get and include a 'types' data member if present
using AllowedTypesMembers =
boost::mp11::mp_transform<
GetTypesMember,
boost::mp11::mp_filter<
HasTypesMember,
AllowedTypesHolders>>;
// collect supported and allowed type lists
using TypeListsToConsider =
boost::mp11::mp_push_front<AllowedTypesMembers, GetTypesMember<SupportedTypesHolder>>;
static_assert(boost::mp11::mp_all_of<TypeListsToConsider, boost::mp11::mp_is_list>::value,
"All 'types' data members must be type lists.");
// converts type list L into a type set (type list with unique elements)
template <typename L>
using MakeSet =
boost::mp11::mp_apply<
boost::mp11::mp_set_push_back,
boost::mp11::mp_append<TypeList<TypeList<>>, L>>;
// type lists converted to type sets
using TypeSetsToConsider = boost::mp11::mp_transform<MakeSet, TypeListsToConsider>;
public:
using types = boost::mp11::mp_apply<boost::mp11::mp_set_intersection, TypeSetsToConsider>;
};
} // namespace op_kernel_type_control
} // namespace onnxruntime
// INTERNAL
// the class name of a tag type identifying an Op for the purposes of type control
#define ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_TAG_CLASS_NAME(OpDomain, OpName) \
TypeControl_##OpDomain##_##OpName##_OpTag
// INTERNAL
// the class name of a tag type identifying a provider for the purposes of type control
#define ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider) \
TypeControl_##OpProvider##_ProviderTag
// INTERNAL
// a tag type identifying an Op argument
#define ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_KERNEL_ARG_TAG( \
OpDomain, OpName, ArgDirection, ArgIndex) \
::onnxruntime::op_kernel_type_control::tags::OpArg< \
::onnxruntime::op_kernel_type_control:: \
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_TAG_CLASS_NAME(OpDomain, OpName), \
::onnxruntime::op_kernel_type_control::OpArgDirection::ArgDirection, \
ArgIndex>
// public macros
/**
* Specifies a supported set of types for a given Op kernel argument.
* This should be specified with the Op kernel implementation.
*
* Note: This should be called from the onnxruntime::op_kernel_type_control namespace.
*
* @param OpProvider The Op provider.
* @param OpDomain The Op domain.
* @param OpName The Op name.
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
* @param ArgIndex Index of the given Op kernel argument.
* @param ... The types.
*/
#define ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES( \
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex, ...) \
class ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_TAG_CLASS_NAME(OpDomain, OpName); \
class ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider); \
template <> \
struct TypesHolder< \
::onnxruntime::op_kernel_type_control::tags::Supported< \
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_KERNEL_ARG_TAG(OpDomain, OpName, ArgDirection, ArgIndex), \
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider)>> { \
using types = ::onnxruntime::TypeList<__VA_ARGS__>; \
};
/**
* TypeList type with the enabled types for a given Op kernel argument.
*
* @param OpProvider The Op provider.
* @param OpDomain The Op domain.
* @param OpName The Op name.
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
* @param ArgIndex Index of the given Op kernel argument.
*/
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( \
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \
::onnxruntime::op_kernel_type_control::EnabledTypes< \
::onnxruntime::op_kernel_type_control::TypesHolder< \
::onnxruntime::op_kernel_type_control::tags::Supported< \
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_KERNEL_ARG_TAG(OpDomain, OpName, ArgDirection, ArgIndex), \
::onnxruntime::op_kernel_type_control:: \
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider)>>, \
::onnxruntime::TypeList< \
::onnxruntime::op_kernel_type_control::TypesHolder< \
::onnxruntime::op_kernel_type_control::tags::Allowed< \
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_KERNEL_ARG_TAG(OpDomain, OpName, ArgDirection, ArgIndex)>>, \
::onnxruntime::op_kernel_type_control::TypesHolder< \
::onnxruntime::op_kernel_type_control::tags::GlobalAllowed>>>::types
/**
* std::tuple type with the enabled types for a given Op kernel argument.
*
* @param OpProvider The Op provider.
* @param OpDomain The Op domain.
* @param OpName The Op name.
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
* @param ArgIndex Index of the given Op kernel argument.
*/
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE( \
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \
::boost::mp11::mp_rename< \
ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( \
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex, SupportedTypeList), \
std::tuple>
/**
* Usage example:
*
* In MyProvider provider's implementation of MyOp kernel:
*
* // specify supported types, i.e., the full set of types that can be enabled
* ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
* MyProvider, DomainContainingMyOp, MyOp, Input, 0,
* int, float, double);
*
* // get enabled types
* using MyOpFirstInputEnabledTypes =
* ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(MyProvider, DomainContainingMyOp, MyOp, Input, 0)
*
* ...
*
* // in the implementation, we can dispatch to the enabled types
* utils::MLTypeCallDispatcherFromTypeList<MyOpFirstInputEnabledTypes> dispatcher{firstInputRuntimeType};
* ...
*/
// all allowed type specifications should be contained in the following file
#include "core/providers/op_kernel_type_control_overrides.inc"

View file

@ -0,0 +1,61 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
namespace onnxruntime {
namespace op_kernel_type_control {
/**
* Specifies an allowed set of types for a given Op kernel argument.
* This can optionally be specified to further limit the enabled types.
*
* Note: This should be called from the onnxruntime::op_kernel_type_control namespace.
*
* @param OpDomain The Op domain.
* @param OpName The Op name.
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
* @param ArgIndex Index of the given Op kernel argument.
* @param ... The types.
*/
#define ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES( \
OpDomain, OpName, ArgDirection, ArgIndex, ...) \
class ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_TAG_CLASS_NAME(OpDomain, OpName); \
template <> \
struct TypesHolder< \
::onnxruntime::op_kernel_type_control::tags::Allowed< \
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_KERNEL_ARG_TAG( \
OpDomain, OpName, ArgDirection, ArgIndex)>> { \
using types = ::onnxruntime::TypeList<__VA_ARGS__>; \
};
/**
* Specifies an allowed set of types globally (applicable to any Op kernel argument).
* This can optionally be specified to further limit the enabled types.
*
* Note: This should be called from the onnxruntime::op_kernel_type_control namespace.
*
* @param ... The types.
*/
#define ORT_SPECIFY_OP_KERNEL_GLOBAL_ALLOWED_TYPES(...) \
template <> \
struct TypesHolder< \
::onnxruntime::op_kernel_type_control::tags::GlobalAllowed> { \
using types = ::onnxruntime::TypeList<__VA_ARGS__>; \
};
// Examples:
// Specify allowed types per Op kernel arg:
// ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES(kOnnxDomain, Cast, Input, 0, float, int64_t);
// ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES(kOnnxDomain, Cast, Output, 0, float, int64_t);
// Specify allowed types globally:
// ORT_SPECIFY_OP_KERNEL_GLOBAL_ALLOWED_TYPES(float, double, int32_t)
// specify allowed types here
// @@insertion_point_begin(allowed_types)@@
// @@insertion_point_end(allowed_types)@@
#undef ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES
#undef ORT_SPECIFY_OP_KERNEL_GLOBAL_ALLOWED_TYPES
} // namespace op_kernel_type_control
} // namespace onnxruntime

View file

@ -126,7 +126,7 @@ class RandomValueGenerator {
template <class T>
inline std::vector<T> FillZeros(const std::vector<int64_t>& dims) {
std::vector<T> val(detail::SizeFromDims(dims), T(0));
std::vector<T> val(detail::SizeFromDims(dims), T{});
return val;
}

View file

@ -27,7 +27,7 @@ void VerifyTensorProtoFileData(const PathString& tensor_proto_path, gsl::span<co
std::vector<T> actual_data{};
actual_data.resize(expected_data.size());
ASSERT_STATUS_OK(utils::UnpackTensor(tensor_proto, actual_data.data(), actual_data.size()));
ASSERT_STATUS_OK(utils::UnpackTensor(tensor_proto, Path{}, actual_data.data(), actual_data.size()));
ASSERT_EQ(gsl::make_span(actual_data), expected_data);
}

View file

@ -148,7 +148,7 @@ TEST(ConstantOfShape, TypeTests) {
//RunTypedTest(TensorProto::INT16, int16_t(16));
RunTypedTest(TensorProto::FLOAT, 1.f);
RunTypedTest(TensorProto::FLOAT16, MLFloat16(5));
RunTypedTest(TensorProto::FLOAT16, MLFloat16(static_cast<uint16_t>(5)));
RunTypedTest(TensorProto::DOUBLE, 1.0);
RunTypedTest(TensorProto::INT32, int32_t(32));
RunTypedTest(TensorProto::INT64, int64_t(64));

View file

@ -291,8 +291,8 @@ TEST(TensorOpTest, CastFromFloat16) {
TEST(TensorOpTest, CastFromString) {
const std::vector<int64_t> shape{2, 2, 2};
std::initializer_list<std::string> string_data = {"-inf", "+INF", "0.9767611f", "0.28280696f",
"-0.12019656f", "5.0f", "NaN", "nan"};
std::initializer_list<std::string> string_data = {"-inf", "+INF", "0.9767611", "0.28280696",
"-0.12019656", "5.0", "NaN", "nan"};
const std::initializer_list<float> float_output = {-(std::numeric_limits<float>::infinity()), std::numeric_limits<float>::infinity(),
0.9767611f, 0.28280696f,
-0.12019656f, 5.0f, NAN, NAN};

View file

@ -152,9 +152,10 @@ TEST(TransposeOpTest, TwoDim_mlfloat16) {
std::vector<int64_t> perm = {1, 0};
std::vector<int64_t> expected_shape({3, 2});
std::initializer_list<MLFloat16> expected_vals = {MLFloat16(1), MLFloat16(4),
MLFloat16(2), MLFloat16(5),
MLFloat16(3), MLFloat16(6)};
std::initializer_list<MLFloat16> expected_vals =
{MLFloat16{static_cast<uint16_t>(1)}, MLFloat16{static_cast<uint16_t>(4)},
MLFloat16{static_cast<uint16_t>(2)}, MLFloat16{static_cast<uint16_t>(5)},
MLFloat16{static_cast<uint16_t>(3)}, MLFloat16{static_cast<uint16_t>(6)}};
TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false);
}

View file

@ -0,0 +1,71 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/op_kernel_type_control.h"
#include <cstdint>
#include "boost/mp11.hpp"
#include "core/common/type_list.h"
namespace onnxruntime {
namespace test {
template <typename A, typename B>
struct TypeSetsEqual {
private:
static_assert(boost::mp11::mp_is_set<A>::value && boost::mp11::mp_is_set<B>::value,
"A and B must both be sets.");
using ABIntersection = boost::mp11::mp_set_intersection<A, B>;
public:
static constexpr bool value =
(boost::mp11::mp_size<A>::value == boost::mp11::mp_size<B>::value) &&
(boost::mp11::mp_size<ABIntersection>::value == boost::mp11::mp_size<A>::value);
};
// test types to match op_kernel_type_control::TypesHolder
template <typename... T>
struct TestTypesHolder {
using types = TypeList<T...>;
};
struct TestTypesHolderUnspecified {
};
// supported + allowed for Op
static_assert(
TypeSetsEqual<
op_kernel_type_control::EnabledTypes<
TestTypesHolder<int32_t, int64_t, float, double>,
TypeList<
TestTypesHolder<float, int64_t, char>,
TestTypesHolderUnspecified>>::types,
TypeList<int64_t, float>>::value,
"unexpected enabled types: supported + allowed + unspecified allowed");
// supported + allowed for Op + allowed globally
static_assert(
TypeSetsEqual<
op_kernel_type_control::EnabledTypes<
TestTypesHolder<int32_t, int64_t, float, double>,
TypeList<
TestTypesHolder<float, int64_t, char>,
TestTypesHolder<int64_t>>>::types,
TypeList<int64_t>>::value,
"unexpected enabled types: supported + allowed + allowed");
// supported
static_assert(
TypeSetsEqual<
op_kernel_type_control::EnabledTypes<
TestTypesHolder<int32_t, int64_t, float, double>,
TypeList<
TestTypesHolderUnspecified,
TestTypesHolderUnspecified>>::types,
TypeList<int32_t, int64_t, float, double>>::value,
"unexpected enabled types: supported + unspecified allowed + unspecified allowed");
} // namespace test
} // namespace onnxruntime