mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-20 02:07:56 +00:00
Op kernel type reduction infrastructure. (#6466)
Add infrastructure to support type reduction in Op kernel implementations. Update Cast and IsInf CPU kernels to use it.
This commit is contained in:
parent
91b19b8364
commit
d850fa63bf
24 changed files with 760 additions and 390 deletions
3
.gitmodules
vendored
3
.gitmodules
vendored
|
|
@ -59,6 +59,9 @@
|
|||
[submodule "cmake/external/optional-lite"]
|
||||
path = cmake/external/optional-lite
|
||||
url = https://github.com/martinmoene/optional-lite.git
|
||||
[submodule "cmake/external/mp11"]
|
||||
path = cmake/external/mp11
|
||||
url = https://github.com/boostorg/mp11.git
|
||||
[submodule "cmake/external/coremltools"]
|
||||
path = cmake/external/coremltools
|
||||
url = https://github.com/apple/coremltools.git
|
||||
|
|
|
|||
|
|
@ -218,6 +218,16 @@
|
|||
"comments": "git submodule at cmake/external/mimalloc"
|
||||
}
|
||||
},
|
||||
{
|
||||
"component": {
|
||||
"type": "git",
|
||||
"git": {
|
||||
"commitHash": "21cace4e574180ba64d9307a5e4ea9e5e94d3e8d",
|
||||
"repositoryUrl": "https://github.com/boostorg/mp11.git"
|
||||
},
|
||||
"comments": "git submodule at cmake/external/mp11"
|
||||
}
|
||||
},
|
||||
{
|
||||
"component": {
|
||||
"type": "git",
|
||||
|
|
|
|||
|
|
@ -662,10 +662,6 @@ set(ONNXRUNTIME_INCLUDE_DIR ${REPO_ROOT}/include/onnxruntime)
|
|||
|
||||
add_subdirectory(external/date EXCLUDE_FROM_ALL)
|
||||
|
||||
if(onnxruntime_PREFER_SYSTEM_LIB)
|
||||
find_package(re2)
|
||||
endif()
|
||||
|
||||
set(SAFEINT_INCLUDE_DIR ${REPO_ROOT}/cmake/external/SafeInt)
|
||||
add_library(safeint_interface INTERFACE)
|
||||
target_include_directories(safeint_interface INTERFACE ${SAFEINT_INCLUDE_DIR})
|
||||
|
|
@ -675,6 +671,11 @@ if(onnxruntime_DISABLE_EXCEPTIONS)
|
|||
add_compile_definitions(optional_CONFIG_NO_EXCEPTIONS=1)
|
||||
endif()
|
||||
|
||||
add_subdirectory(external/mp11 EXCLUDE_FROM_ALL)
|
||||
|
||||
if(onnxruntime_PREFER_SYSTEM_LIB)
|
||||
find_package(re2)
|
||||
endif()
|
||||
if(NOT TARGET re2::re2)
|
||||
add_subdirectory(external/re2 EXCLUDE_FROM_ALL)
|
||||
set_target_properties(re2 PROPERTIES FOLDER "External/re2")
|
||||
|
|
|
|||
1
cmake/external/mp11
vendored
Submodule
1
cmake/external/mp11
vendored
Submodule
|
|
@ -0,0 +1 @@
|
|||
Subproject commit 21cace4e574180ba64d9307a5e4ea9e5e94d3e8d
|
||||
|
|
@ -105,6 +105,8 @@ target_include_directories(onnxruntime_common
|
|||
$<TARGET_PROPERTY:safeint_interface,INTERFACE_INCLUDE_DIRECTORIES>
|
||||
${OPTIONAL_LITE_INCLUDE_DIR})
|
||||
|
||||
target_link_libraries(onnxruntime_common Boost::mp11)
|
||||
|
||||
if(NOT WIN32)
|
||||
target_include_directories(onnxruntime_common PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/external/nsync/public")
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ file(GLOB onnxruntime_cpu_featurizers_cc_srcs CONFIGURE_DEPENDS
|
|||
file(GLOB onnxruntime_providers_common_srcs CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/*.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/*.cc"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/op_kernel_type_control_overrides.inc"
|
||||
)
|
||||
|
||||
if(onnxruntime_USE_NUPHAR)
|
||||
|
|
|
|||
|
|
@ -397,6 +397,7 @@ target_include_directories(winml_lib_image PRIVATE ${REPO_ROOT}/cmake/external/p
|
|||
target_include_directories(winml_lib_image PRIVATE ${ONNXRUNTIME_INCLUDE_DIR}/core/platform/windows)
|
||||
target_include_directories(winml_lib_image PRIVATE ${REPO_ROOT}/cmake/external/flatbuffers/include)
|
||||
target_include_directories(winml_lib_image PRIVATE ${REPO_ROOT}/cmake/external/optional-lite/include)
|
||||
target_include_directories(winml_lib_image PRIVATE ${REPO_ROOT}/cmake/external/mp11/include)
|
||||
|
||||
# Properties
|
||||
set_target_properties(winml_lib_image
|
||||
|
|
@ -512,6 +513,7 @@ target_include_directories(winml_lib_api PRIVATE ${REPO_ROOT}/cmake/external/gsl
|
|||
target_include_directories(winml_lib_api PRIVATE ${REPO_ROOT}/cmake/external/SafeInt)
|
||||
target_include_directories(winml_lib_api PRIVATE ${REPO_ROOT}/cmake/external/flatbuffers/include)
|
||||
target_include_directories(winml_lib_api PRIVATE ${REPO_ROOT}/cmake/external/optional-lite/include)
|
||||
target_include_directories(winml_lib_api PRIVATE ${REPO_ROOT}/cmake/external/mp11/include)
|
||||
|
||||
# Properties
|
||||
set_target_properties(winml_lib_api
|
||||
|
|
@ -594,6 +596,7 @@ target_include_directories(winml_lib_api_experimental PRIVATE ${REPO_ROOT}/cmake
|
|||
target_include_directories(winml_lib_api_experimental PRIVATE ${REPO_ROOT}/cmake/external/SafeInt)
|
||||
target_include_directories(winml_lib_api_experimental PRIVATE ${REPO_ROOT}/cmake/external/flatbuffers/include)
|
||||
target_include_directories(winml_lib_api_experimental PRIVATE ${REPO_ROOT}/cmake/external/optional-lite/include)
|
||||
target_include_directories(winml_lib_api_experimental PRIVATE ${REPO_ROOT}/cmake/external/mp11/include)
|
||||
|
||||
# Properties
|
||||
set_target_properties(winml_lib_api_experimental
|
||||
|
|
@ -748,6 +751,7 @@ target_include_directories(winml_dll PRIVATE ${REPO_ROOT}/cmake/external/eigen)
|
|||
target_include_directories(winml_dll PRIVATE ${REPO_ROOT}/cmake/external/SafeInt)
|
||||
target_include_directories(winml_dll PRIVATE ${REPO_ROOT}/cmake/external/flatbuffers/include)
|
||||
target_include_directories(winml_dll PRIVATE ${REPO_ROOT}/cmake/external/optional-lite/include)
|
||||
target_include_directories(winml_dll PRIVATE ${REPO_ROOT}/cmake/external/mp11/include)
|
||||
|
||||
# Properties
|
||||
set_target_properties(winml_dll
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ function (get_winml_test_model_src
|
|||
"${winml_test_src_path}/model/*.cpp")
|
||||
set(${output_winml_test_model_src} ${winml_test_model_src} PARENT_SCOPE)
|
||||
set(${winml_test_model_libs} onnx_test_data_proto onnx_test_runner_common onnxruntime_common onnxruntime_mlas
|
||||
onnxruntime_graph onnxruntime_test_utils onnxruntime_framework onnxruntime_flatbuffers PARENT_SCOPE)
|
||||
onnxruntime_graph onnxruntime_test_utils onnxruntime_framework onnxruntime_util onnxruntime_flatbuffers PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
file(GLOB winml_test_common_src CONFIGURE_DEPENDS
|
||||
|
|
|
|||
|
|
@ -52,11 +52,12 @@ class NonTensorTypeBase;
|
|||
class PrimitiveDataTypeBase;
|
||||
|
||||
// MLFloat16
|
||||
union MLFloat16 {
|
||||
struct MLFloat16 {
|
||||
uint16_t val;
|
||||
|
||||
explicit MLFloat16(uint16_t x) : val(x) {}
|
||||
MLFloat16() : val(0) {}
|
||||
explicit MLFloat16(uint16_t x) : val(x) {}
|
||||
explicit MLFloat16(float f);
|
||||
|
||||
// Taken from https://stackoverflow.com/a/60047308/12627730
|
||||
float AsFloat(uint32_t x) const {
|
||||
|
|
|
|||
|
|
@ -3,12 +3,16 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "boost/mp11.hpp"
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/type_list.h"
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
|
||||
|
|
@ -341,25 +345,54 @@ class MLTypeCallDispatcherRet {
|
|||
}
|
||||
};
|
||||
|
||||
// Version of the MLTypeDispatcher that has an input type which is passed through ('carried')
|
||||
// as the first type parameter in the call to Fn when dispatching.
|
||||
template <typename TCarried, template <typename, typename> class Fn, typename... Types>
|
||||
class MLTypeCallDispatcherWithCarriedType {
|
||||
// Version of MLTypeCallDispatcher that takes supported types as class-level template parameters.
|
||||
// This enables easier use with type list representations of the supported types.
|
||||
// The invocation-related template parameters like Fn move to the individual Invoke() methods.
|
||||
// TODO consolidate this with the other MLTypeCallDispatcher classes
|
||||
// can add additional methods to cover their usages, but need to update call sites
|
||||
template <typename... Types>
|
||||
class MLTypeCallDispatcher2 {
|
||||
static_assert(boost::mp11::mp_is_set<TypeList<Types...>>::value,
|
||||
"MLTypeCallDispatcher requires a set of unique types.");
|
||||
|
||||
int32_t dt_type_;
|
||||
|
||||
public:
|
||||
explicit MLTypeCallDispatcherWithCarriedType(int32_t dt_type) noexcept : dt_type_(dt_type) {}
|
||||
explicit MLTypeCallDispatcher2(int32_t dt_type) noexcept : dt_type_(dt_type) {}
|
||||
|
||||
template <typename... Args>
|
||||
template <template <typename> class Fn, typename... Args>
|
||||
void Invoke(Args&&... args) const {
|
||||
mltype_dispatcher_internal::CallableDispatchableHelper helper(dt_type_);
|
||||
int results[] = {0, helper.template Invoke<Types>(Fn<TCarried, Types>(), std::forward<Args>(args)...)...};
|
||||
ORT_UNUSED_PARAMETER(results);
|
||||
ORT_ENFORCE(helper.called_ < 2, "Check for duplicate types in MLTypeCallDispatcher");
|
||||
|
||||
static_cast<void>(std::array<int, sizeof...(Types)>{
|
||||
helper.template Invoke<Types>(Fn<Types>(), std::forward<Args>(args)...)...});
|
||||
|
||||
// avoid "unused parameter" warning for the case where Types is empty
|
||||
static_cast<void>(std::array<int, sizeof...(Args)>{(ORT_UNUSED_PARAMETER(args), 0)...});
|
||||
|
||||
ORT_ENFORCE(helper.called_ == 1, "Unsupported data type: ", dt_type_);
|
||||
}
|
||||
|
||||
template <template <typename...> class Fn, typename LeadingTemplateArgTypeList, typename... Args>
|
||||
void InvokeWithLeadingTemplateArgs(Args&&... args) const {
|
||||
mltype_dispatcher_internal::CallableDispatchableHelper helper(dt_type_);
|
||||
|
||||
static_cast<void>(std::array<int, sizeof...(Types)>{
|
||||
helper.template Invoke<Types>(
|
||||
boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
|
||||
std::forward<Args>(args)...)...});
|
||||
|
||||
// avoid "unused parameter" warning for the case where Types is empty
|
||||
static_cast<void>(std::array<int, sizeof...(Args)>{(ORT_UNUSED_PARAMETER(args), 0)...});
|
||||
|
||||
ORT_ENFORCE(helper.called_ == 1, "Unsupported data type: ", dt_type_);
|
||||
}
|
||||
};
|
||||
|
||||
// the type MLTypeCallDispatcher2<T...> given a type list L<T...>
|
||||
template <typename L>
|
||||
using MLTypeCallDispatcherFromTypeList = boost::mp11::mp_apply<MLTypeCallDispatcher2, L>;
|
||||
|
||||
namespace data_types_internal {
|
||||
|
||||
enum class ContainerType : uint16_t {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
|
||||
#include <functional>
|
||||
|
||||
#include "boost/mp11.hpp"
|
||||
|
||||
#include "core/common/exceptions.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/common/status.h"
|
||||
|
|
@ -481,4 +483,17 @@ inline std::vector<MLDataType> BuildKernelDefConstraints() {
|
|||
return {DataTypeImpl::GetTensorType<T>(), DataTypeImpl::GetTensorType<Types>()...};
|
||||
}
|
||||
|
||||
// functor that calls BuildKernelDefConstraints()
|
||||
template <typename... Types>
|
||||
struct BuildKernelDefConstraintsFunctor {
|
||||
std::vector<MLDataType> operator()() const {
|
||||
return BuildKernelDefConstraints<Types...>();
|
||||
}
|
||||
};
|
||||
|
||||
// the type BuildKernelDefConstraintsFunctor<T...> given a type list L<T...>
|
||||
template <typename L>
|
||||
using BuildKernelDefConstraintsFunctorFromTypeList =
|
||||
boost::mp11::mp_apply<BuildKernelDefConstraintsFunctor, L>;
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
12
onnxruntime/core/common/type_list.h
Normal file
12
onnxruntime/core/common/type_list.h
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// this type represents a compile-time list of types
|
||||
template <typename... T>
|
||||
struct TypeList {};
|
||||
|
||||
}
|
||||
|
|
@ -7,6 +7,7 @@
|
|||
#include "core/framework/sparse_tensor.h"
|
||||
#include "core/framework/data_types_internal.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "core/util/math.h"
|
||||
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic push
|
||||
|
|
@ -21,6 +22,9 @@
|
|||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
MLFloat16::MLFloat16(float f) : val{math::floatToHalf(f)} {}
|
||||
|
||||
// Return the MLDataType used for a generic Tensor
|
||||
template <>
|
||||
MLDataType DataTypeImpl::GetType<Tensor>() {
|
||||
|
|
|
|||
|
|
@ -1,193 +1,231 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <cstddef>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
|
||||
#include "boost/mp11.hpp"
|
||||
|
||||
#include "gsl/gsl"
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/type_list.h"
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/framework/data_types_internal.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "core/util/math.h"
|
||||
#include "core/providers/op_kernel_type_control.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
|
||||
#include "Eigen/src/Core/arch/Default/Half.h"
|
||||
|
||||
#if defined(_M_AMD64)
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#endif
|
||||
|
||||
// FUTURE:
|
||||
// Float16 and String have expensive special cased handling. Enable by default, but provide an easy way to disable
|
||||
// in the future if needed. Disabling both saves ~50KB.
|
||||
#define CAST_FLOAT16_ENABLED
|
||||
#define CAST_STRING_ENABLED
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace boost::mp11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace op_kernel_type_control {
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
|
||||
kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0,
|
||||
bool,
|
||||
float, double,
|
||||
uint8_t, uint16_t, uint32_t, uint64_t,
|
||||
int8_t, int16_t, int32_t, int64_t,
|
||||
MLFloat16, BFloat16,
|
||||
std::string);
|
||||
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
|
||||
kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0,
|
||||
bool,
|
||||
float, double,
|
||||
uint8_t, uint16_t, uint32_t, uint64_t,
|
||||
int8_t, int16_t, int32_t, int64_t,
|
||||
MLFloat16, BFloat16,
|
||||
std::string);
|
||||
} // namespace op_kernel_type_control
|
||||
|
||||
namespace {
|
||||
template <typename SrcType, typename DstType>
|
||||
inline void CastData(const Tensor& in, Tensor& out, const TensorShape& shape) {
|
||||
ptrdiff_t shape_size = gsl::narrow<ptrdiff_t>(shape.Size());
|
||||
auto in_vector = ConstEigenVectorMap<SrcType>(in.Data<SrcType>(), shape_size);
|
||||
auto output_vector = EigenVectorMap<DstType>(out.MutableData<DstType>(), shape_size);
|
||||
output_vector = in_vector.template cast<DstType>();
|
||||
}
|
||||
|
||||
#ifdef CAST_FLOAT16_ENABLED
|
||||
template <>
|
||||
inline void CastData<float, MLFloat16>(const Tensor& in, Tensor& out, const TensorShape& shape) {
|
||||
auto out_data = out.MutableData<MLFloat16>();
|
||||
ptrdiff_t shape_size = gsl::narrow<ptrdiff_t>(shape.Size());
|
||||
auto in_vector = ConstEigenVectorMap<float>(in.Data<float>(), shape_size);
|
||||
auto output_vector = EigenVectorMap<Eigen::half>(static_cast<Eigen::half*>(static_cast<void*>(out_data)),
|
||||
shape_size);
|
||||
output_vector = in_vector.template cast<Eigen::half>();
|
||||
}
|
||||
using EnabledSrcTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0);
|
||||
using EnabledDstTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0);
|
||||
|
||||
template <>
|
||||
inline void CastData<MLFloat16, float>(const Tensor& in, Tensor& out, const TensorShape& shape) {
|
||||
auto out_data = out.MutableData<float>();
|
||||
auto in_data = in.Data<MLFloat16>();
|
||||
ptrdiff_t shape_size = gsl::narrow<ptrdiff_t>(shape.Size());
|
||||
#if defined(_M_AMD64)
|
||||
MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size);
|
||||
#else
|
||||
auto in_vector = ConstEigenVectorMap<Eigen::half>(static_cast<const Eigen::half*>(static_cast<const void*>(in_data)),
|
||||
shape_size);
|
||||
auto output_vector = EigenVectorMap<float>(out_data, shape_size);
|
||||
output_vector = in_vector.template cast<float>();
|
||||
#endif
|
||||
}
|
||||
using IndirectCastTypes = TypeList<MLFloat16, BFloat16>;
|
||||
|
||||
template <>
|
||||
inline void CastData<float, BFloat16>(const Tensor& in, Tensor& out, const TensorShape& shape) {
|
||||
auto out_data = out.template MutableData<BFloat16>();
|
||||
ptrdiff_t shape_size = gsl::narrow<ptrdiff_t>(shape.Size());
|
||||
auto in_vector = ConstEigenVectorMap<float>(in.template Data<float>(), shape_size);
|
||||
auto output_vector = EigenVectorMap<BFloat16>(out_data, shape_size);
|
||||
output_vector = in_vector.template cast<BFloat16>();
|
||||
}
|
||||
template <typename Type>
|
||||
using IsDirectCastType = mp_not<mp_contains<IndirectCastTypes, Type>>;
|
||||
|
||||
template <>
|
||||
inline void CastData<BFloat16, float>(const Tensor& in, Tensor& out, const TensorShape& shape) {
|
||||
auto out_data = out.template MutableData<float>();
|
||||
auto in_data = in.template Data<BFloat16>();
|
||||
ptrdiff_t shape_size = gsl::narrow<ptrdiff_t>(shape.Size());
|
||||
auto in_vector = ConstEigenVectorMap<BFloat16>(in_data, shape_size);
|
||||
auto output_vector = EigenVectorMap<float>(out_data, shape_size);
|
||||
output_vector = in_vector.unaryExpr([](BFloat16 val) { return val.ToFloat(); });
|
||||
}
|
||||
#endif
|
||||
template <typename... Types>
|
||||
using AreAllDirectCastTypes = mp_all<IsDirectCastType<Types>...>;
|
||||
|
||||
#ifdef CAST_STRING_ENABLED
|
||||
// string cast helpers
|
||||
|
||||
// handle floating point input separately
|
||||
template <typename SrcType>
|
||||
typename std::enable_if<std::is_floating_point<SrcType>::value, void>::type
|
||||
CastToStringData(const Tensor& in, Tensor& out, const TensorShape& shape) {
|
||||
const int64_t len = shape.Size();
|
||||
const auto input_data = in.DataAsSpan<SrcType>();
|
||||
auto output_data = out.MutableDataAsSpan<std::string>();
|
||||
|
||||
for (int i = 0; i < len; ++i) {
|
||||
if (std::isnan(input_data[i])) {
|
||||
output_data[i] = "NaN";
|
||||
} else if (std::isinf(input_data[i])) {
|
||||
if (input_data[i] < std::numeric_limits<SrcType>::lowest()) {
|
||||
output_data[i] = "-INF";
|
||||
} else {
|
||||
output_data[i] = "INF";
|
||||
}
|
||||
CastToString(const SrcType& input, std::string& output) {
|
||||
if (std::isnan(input)) {
|
||||
output = "NaN";
|
||||
} else if (std::isinf(input)) {
|
||||
if (input < std::numeric_limits<SrcType>::lowest()) {
|
||||
output = "-INF";
|
||||
} else {
|
||||
// setprecision to 8 to match numpy default behavior
|
||||
std::ostringstream convert;
|
||||
convert << std::setprecision(8) << input_data[i];
|
||||
output_data[i] = convert.str();
|
||||
output = "INF";
|
||||
}
|
||||
} else {
|
||||
// setprecision to 8 to match numpy default behavior
|
||||
std::ostringstream convert;
|
||||
convert << std::setprecision(8) << input;
|
||||
output = convert.str();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcType>
|
||||
typename std::enable_if<!std::is_floating_point<SrcType>::value, void>::type
|
||||
CastToStringData(const Tensor& in, Tensor& out, const TensorShape& shape) {
|
||||
const int64_t len = shape.Size();
|
||||
const auto input_data = in.DataAsSpan<SrcType>();
|
||||
auto output_data = out.MutableDataAsSpan<std::string>();
|
||||
|
||||
for (int i = 0; i < len; ++i) {
|
||||
std::ostringstream convert;
|
||||
convert << input_data[i];
|
||||
output_data[i] = convert.str();
|
||||
}
|
||||
CastToString(const SrcType& input, std::string& output) {
|
||||
std::ostringstream convert;
|
||||
convert << input;
|
||||
output = convert.str();
|
||||
}
|
||||
|
||||
template <typename DstType>
|
||||
void CastFromStringData(const Tensor& in, Tensor& out, const TensorShape& shape) {
|
||||
const int64_t len = shape.Size();
|
||||
if (std::is_same<DstType, float>::value) {
|
||||
auto* mutable_data = out.MutableData<float>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
mutable_data[i] = std::stof(in.Data<std::string>()[i]);
|
||||
}
|
||||
} else if (std::is_same<DstType, double>::value) {
|
||||
auto* mutable_data = out.MutableData<double>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
mutable_data[i] = std::stod(in.Data<std::string>()[i]);
|
||||
}
|
||||
} else if (std::is_same<DstType, int8_t>::value) {
|
||||
auto* mutable_data = out.MutableData<int8_t>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
int temp_i = std::stoi(in.Data<std::string>()[i]);
|
||||
mutable_data[i] = static_cast<int8_t>(temp_i);
|
||||
}
|
||||
} else if (std::is_same<DstType, uint8_t>::value) {
|
||||
auto* mutable_data = out.MutableData<uint8_t>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
unsigned long temp_ui = std::stoul(in.Data<std::string>()[i]);
|
||||
mutable_data[i] = static_cast<uint8_t>(temp_ui);
|
||||
}
|
||||
} else if (std::is_same<DstType, int16_t>::value) {
|
||||
auto* mutable_data = out.MutableData<int16_t>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
int temp_i = std::stoi(in.Data<std::string>()[i]);
|
||||
mutable_data[i] = static_cast<int16_t>(temp_i);
|
||||
}
|
||||
} else if (std::is_same<DstType, uint16_t>::value) {
|
||||
auto* mutable_data = out.MutableData<uint16_t>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
unsigned long temp_ui = std::stoul(in.Data<std::string>()[i]);
|
||||
mutable_data[i] = static_cast<uint16_t>(temp_ui);
|
||||
}
|
||||
} else if (std::is_same<DstType, int32_t>::value) {
|
||||
auto* mutable_data = out.MutableData<int32_t>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
mutable_data[i] = std::stol(in.Data<std::string>()[i]);
|
||||
}
|
||||
} else if (std::is_same<DstType, uint32_t>::value) {
|
||||
auto* mutable_data = out.MutableData<uint32_t>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
mutable_data[i] = std::stoul(in.Data<std::string>()[i]);
|
||||
}
|
||||
} else if (std::is_same<DstType, int64_t>::value) {
|
||||
auto* mutable_data = out.MutableData<int64_t>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
mutable_data[i] = std::stoll(in.Data<std::string>()[i]);
|
||||
}
|
||||
} else if (std::is_same<DstType, uint64_t>::value) {
|
||||
auto* mutable_data = out.MutableData<uint64_t>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
mutable_data[i] = std::stoull(in.Data<std::string>()[i]);
|
||||
}
|
||||
} else {
|
||||
#ifdef ORT_NO_RTTI
|
||||
ORT_THROW("Unsupported type in cast op");
|
||||
#else
|
||||
ORT_THROW("Unsupported type in cast op: from String to ", typeid(DstType).name());
|
||||
#endif
|
||||
typename std::enable_if<std::is_floating_point<DstType>::value, void>::type
|
||||
CastFromString(const std::string& input, DstType& output) {
|
||||
static_assert(sizeof(DstType) <= sizeof(double),
|
||||
"largest supported floating point type is double");
|
||||
output = gsl::narrow_cast<DstType>(std::stod(input));
|
||||
}
|
||||
|
||||
template <typename DstType>
|
||||
typename std::enable_if<std::is_integral<DstType>::value && std::is_unsigned<DstType>::value, void>::type
|
||||
CastFromString(const std::string& input, DstType& output) {
|
||||
static_assert(sizeof(DstType) <= sizeof(unsigned long long),
|
||||
"largest supported unsigned integral type is unsigned long long");
|
||||
output = gsl::narrow_cast<DstType>(std::stoull(input));
|
||||
}
|
||||
|
||||
template <typename DstType>
|
||||
typename std::enable_if<std::is_integral<DstType>::value && std::is_signed<DstType>::value, void>::type
|
||||
CastFromString(const std::string& input, DstType& output) {
|
||||
static_assert(sizeof(DstType) <= sizeof(long long),
|
||||
"largest supported signed integral type is long long");
|
||||
output = gsl::narrow_cast<DstType>(std::stoll(input));
|
||||
}
|
||||
|
||||
// generic scalar X -> Y
|
||||
template <typename SrcType, typename DstType>
|
||||
struct ScalarDirectCaster {
|
||||
void Cast(const SrcType& in, DstType& out) const {
|
||||
out = static_cast<DstType>(in);
|
||||
}
|
||||
};
|
||||
|
||||
// scalar X -> string
|
||||
template <typename SrcType>
|
||||
struct ScalarDirectCaster<SrcType, std::string> {
|
||||
void Cast(const SrcType& in, std::string& out) const {
|
||||
CastToString<SrcType>(in, out);
|
||||
}
|
||||
};
|
||||
|
||||
// scalar string -> X
|
||||
template <typename DstType>
|
||||
struct ScalarDirectCaster<std::string, DstType> {
|
||||
void Cast(const std::string& in, DstType& out) const {
|
||||
CastFromString<DstType>(in, out);
|
||||
}
|
||||
};
|
||||
|
||||
// helper for indirect cast types
|
||||
template <typename SrcType, typename DstType, typename IntermediateType>
|
||||
struct ScalarIndirectCaster {
|
||||
void Cast(const SrcType& in, DstType& out) const {
|
||||
IntermediateType intermediate;
|
||||
ScalarDirectCaster<SrcType, IntermediateType>{}.Cast(in, intermediate);
|
||||
ScalarDirectCaster<IntermediateType, DstType>{}.Cast(intermediate, out);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SrcType, typename DstType, class Enable = void>
|
||||
struct ScalarCaster;
|
||||
|
||||
template <typename SrcType, typename DstType>
|
||||
struct ScalarCaster<
|
||||
SrcType, DstType,
|
||||
typename std::enable_if<AreAllDirectCastTypes<SrcType, DstType>::value>::type> {
|
||||
void Cast(const SrcType& in, DstType& out) const {
|
||||
ScalarDirectCaster<SrcType, DstType>{}.Cast(in, out);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SrcType, typename DstType>
|
||||
struct ScalarCaster<
|
||||
SrcType, DstType,
|
||||
typename std::enable_if<!AreAllDirectCastTypes<SrcType, DstType>::value>::type> {
|
||||
void Cast(const SrcType& in, DstType& out) const {
|
||||
ScalarIndirectCaster<SrcType, DstType, float>{}.Cast(in, out);
|
||||
}
|
||||
};
|
||||
|
||||
// generic tensor X -> Y
|
||||
template <typename SrcType, typename DstType>
|
||||
struct TensorCaster {
|
||||
void Cast(const Tensor& in, Tensor& out, const TensorShape& shape) const {
|
||||
const std::ptrdiff_t shape_size = gsl::narrow<std::ptrdiff_t>(shape.Size());
|
||||
const auto in_vector = ConstEigenVectorMap<SrcType>(in.Data<SrcType>(), shape_size);
|
||||
auto out_vector = EigenVectorMap<DstType>(out.MutableData<DstType>(), shape_size);
|
||||
out_vector = in_vector.unaryExpr([](const SrcType& in_scalar) {
|
||||
DstType out_scalar;
|
||||
ScalarCaster<SrcType, DstType>{}.Cast(in_scalar, out_scalar);
|
||||
return out_scalar;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SrcType, typename DstType>
|
||||
void CastStringTensor(const Tensor& in, Tensor& out, const TensorShape& shape) {
|
||||
static_assert(std::is_same<SrcType, std::string>::value || std::is_same<DstType, std::string>::value,
|
||||
"Either SrcType or DstType must be std::string.");
|
||||
const std::ptrdiff_t shape_size = gsl::narrow<std::ptrdiff_t>(shape.Size());
|
||||
const auto in_data = in.DataAsSpan<SrcType>();
|
||||
const auto out_data = out.MutableDataAsSpan<DstType>();
|
||||
for (std::ptrdiff_t i = 0; i < shape_size; ++i) {
|
||||
ScalarCaster<SrcType, DstType>{}.Cast(in_data[i], out_data[i]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
// tensor X -> string
|
||||
template <typename SrcType>
|
||||
struct TensorCaster<SrcType, std::string> {
|
||||
void Cast(const Tensor& in, Tensor& out, const TensorShape& shape) const {
|
||||
CastStringTensor<SrcType, std::string>(in, out, shape);
|
||||
}
|
||||
};
|
||||
|
||||
// tensor string -> X
|
||||
template <typename DstType>
|
||||
struct TensorCaster<std::string, DstType> {
|
||||
void Cast(const Tensor& in, Tensor& out, const TensorShape& shape) const {
|
||||
CastStringTensor<std::string, DstType>(in, out, shape);
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(_M_AMD64)
|
||||
// tensor MLFloat16 -> float
|
||||
template <>
|
||||
struct TensorCaster<MLFloat16, float> {
|
||||
void Cast(const Tensor& in, Tensor& out, const TensorShape& shape) const {
|
||||
auto out_data = out.MutableData<float>();
|
||||
auto in_data = in.Data<MLFloat16>();
|
||||
const size_t shape_size = gsl::narrow<size_t>(shape.Size());
|
||||
MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
class Cast final : public OpKernel {
|
||||
public:
|
||||
|
|
@ -201,46 +239,63 @@ class Cast final : public OpKernel {
|
|||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
template <typename TSrc>
|
||||
struct SrcDispatcher;
|
||||
|
||||
template <typename TSrc, typename TDest>
|
||||
struct Dispatcher;
|
||||
|
||||
template <typename T>
|
||||
struct StringDispatcher;
|
||||
|
||||
ONNX_NAMESPACE::TensorProto_DataType to_;
|
||||
};
|
||||
|
||||
const std::vector<MLDataType> castOpTypeConstraints{
|
||||
DataTypeImpl::GetTensorType<bool>(),
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<uint16_t>(),
|
||||
DataTypeImpl::GetTensorType<uint32_t>(),
|
||||
DataTypeImpl::GetTensorType<uint64_t>(),
|
||||
DataTypeImpl::GetTensorType<int8_t>(),
|
||||
DataTypeImpl::GetTensorType<int16_t>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
#ifdef CAST_FLOAT16_ENABLED
|
||||
DataTypeImpl::GetTensorType<MLFloat16>(),
|
||||
DataTypeImpl::GetTensorType<BFloat16>(),
|
||||
#endif
|
||||
#ifdef CAST_STRING_ENABLED
|
||||
DataTypeImpl::GetTensorType<std::string>(),
|
||||
#endif
|
||||
template <typename TSrc, typename TDst>
|
||||
struct Dispatcher {
|
||||
void operator()(const Tensor& src, Tensor& dst, const TensorShape& shape) {
|
||||
TensorCaster<TSrc, TDst>{}.Cast(src, dst, shape);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TSrc>
|
||||
struct SrcDispatcher {
|
||||
void operator()(int32_t to, const Tensor& src, Tensor& dst, const TensorShape& shape) {
|
||||
using DstTypes = mp_remove_if_q<EnabledDstTypes, mp_bind_front<std::is_same, TSrc>>;
|
||||
utils::MLTypeCallDispatcherFromTypeList<DstTypes> dispatcher{to};
|
||||
dispatcher.template InvokeWithLeadingTemplateArgs<Dispatcher, TypeList<TSrc>>(src, dst, shape);
|
||||
}
|
||||
};
|
||||
|
||||
Status Cast::Compute(OpKernelContext* context) const {
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
const TensorShape& shape = X->Shape();
|
||||
Tensor* Y = context->Output(0, shape);
|
||||
|
||||
if (shape.Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const auto from = X->GetElementType();
|
||||
|
||||
if (from == to_) {
|
||||
// will copy if X and Y have different buffers
|
||||
CopyCpuTensor(X, Y);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
utils::MLTypeCallDispatcherFromTypeList<EnabledSrcTypes> dispatcher{from};
|
||||
dispatcher.Invoke<SrcDispatcher>(to_, *X, *Y, shape);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const std::vector<MLDataType> castSrcTypeConstraints =
|
||||
BuildKernelDefConstraintsFunctorFromTypeList<EnabledSrcTypes>{}();
|
||||
|
||||
const std::vector<MLDataType> castDstTypeConstraints =
|
||||
BuildKernelDefConstraintsFunctorFromTypeList<EnabledDstTypes>{}();
|
||||
|
||||
} // namespace
|
||||
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Cast,
|
||||
6,
|
||||
12,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T1", castOpTypeConstraints)
|
||||
.TypeConstraint("T2", castOpTypeConstraints)
|
||||
.TypeConstraint("T1", castSrcTypeConstraints)
|
||||
.TypeConstraint("T2", castDstTypeConstraints)
|
||||
.MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing
|
||||
Cast);
|
||||
|
||||
|
|
@ -248,152 +303,9 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
Cast,
|
||||
13,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T1", castOpTypeConstraints)
|
||||
.TypeConstraint("T2", castOpTypeConstraints)
|
||||
.TypeConstraint("T1", castSrcTypeConstraints)
|
||||
.TypeConstraint("T2", castDstTypeConstraints)
|
||||
.MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing
|
||||
Cast);
|
||||
|
||||
// default dispatch
|
||||
template <typename TSrc, typename TDst>
|
||||
struct Cast::Dispatcher {
|
||||
void operator()(const Tensor& src, Tensor& dst, const TensorShape& shape) {
|
||||
CastData<TSrc, TDst>(src, dst, shape);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TSrc>
|
||||
struct Cast::SrcDispatcher {
|
||||
void operator()(int32_t to, const Tensor& src, Tensor& dst, const TensorShape& shape) {
|
||||
utils::MLTypeCallDispatcherWithCarriedType<TSrc, Cast::Dispatcher,
|
||||
float, double, int8_t, uint8_t, int16_t, uint16_t,
|
||||
int32_t, uint32_t, int64_t, uint64_t, bool>
|
||||
t_disp(to);
|
||||
|
||||
t_disp.Invoke(src, dst, shape);
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef CAST_STRING_ENABLED
|
||||
template <typename T>
|
||||
struct Cast::StringDispatcher {
|
||||
void operator()(bool to_string, const Tensor& src, Tensor& dst, const TensorShape& shape) {
|
||||
if (to_string) {
|
||||
CastToStringData<T>(src, dst, shape);
|
||||
} else {
|
||||
CastFromStringData<T>(src, dst, shape);
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
Status Cast::Compute(OpKernelContext* context) const {
|
||||
Status status = Status::OK();
|
||||
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
const TensorShape& shape = X->Shape();
|
||||
Tensor* Y = context->Output(0, shape);
|
||||
if (shape.Size() == 0) {
|
||||
return status;
|
||||
}
|
||||
|
||||
auto from = X->GetElementType();
|
||||
|
||||
if (from == to_) {
|
||||
// will copy if X and Y have different buffers
|
||||
CopyCpuTensor(X, Y);
|
||||
return status;
|
||||
}
|
||||
|
||||
#ifdef CAST_STRING_ENABLED
|
||||
// special case strings
|
||||
if (from == ONNX_NAMESPACE::TensorProto_DataType_STRING ||
|
||||
to_ == ONNX_NAMESPACE::TensorProto_DataType_STRING) {
|
||||
bool to_string = to_ == ONNX_NAMESPACE::TensorProto_DataType_STRING;
|
||||
|
||||
utils::MLTypeCallDispatcher<StringDispatcher,
|
||||
float, double,
|
||||
#ifdef CAST_FLOAT16_ENABLED
|
||||
MLFloat16, /*BFloat16,*/
|
||||
#endif
|
||||
int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, int64_t, uint64_t,
|
||||
bool>
|
||||
t_disp(to_string ? from : to_);
|
||||
|
||||
t_disp.Invoke(to_string, *X, *Y, shape);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
auto do_cast = [](int32_t from, int32_t to, const Tensor& src, Tensor& dst, const TensorShape& shape) {
|
||||
utils::MLTypeCallDispatcher<SrcDispatcher,
|
||||
float, double, // MLFloat16 is special cased below
|
||||
int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, int64_t, uint64_t, bool>
|
||||
t_disp(from);
|
||||
|
||||
t_disp.Invoke(to, src, dst, shape);
|
||||
};
|
||||
|
||||
#ifdef CAST_FLOAT16_ENABLED
|
||||
// MLFloat16 needs special handling
|
||||
if (from == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
|
||||
if (to_ == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
CastData<MLFloat16, float>(*X, *Y, shape);
|
||||
} else {
|
||||
// need to cast to float first in a temporary buffer
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
auto tmp_buffer = IAllocator::MakeUniquePtr<float>(allocator, gsl::narrow<size_t>(shape.Size()));
|
||||
Tensor tmp_tensor(DataTypeImpl::GetType<float>(), shape, tmp_buffer.get(), allocator->Info());
|
||||
|
||||
CastData<MLFloat16, float>(*X, tmp_tensor, shape);
|
||||
do_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, to_, tmp_tensor, *Y, shape);
|
||||
}
|
||||
} else if (to_ == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
|
||||
if (from == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
CastData<float, MLFloat16>(*X, *Y, shape);
|
||||
} else {
|
||||
// need to cast to float first in a temporary buffer
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
auto tmp_buffer = IAllocator::MakeUniquePtr<float>(allocator, gsl::narrow<size_t>(shape.Size()));
|
||||
Tensor tmp_tensor(DataTypeImpl::GetType<float>(), shape, tmp_buffer.get(), allocator->Info());
|
||||
|
||||
do_cast(from, ONNX_NAMESPACE::TensorProto_DataType_FLOAT, *X, tmp_tensor, shape);
|
||||
CastData<float, MLFloat16>(tmp_tensor, *Y, shape);
|
||||
}
|
||||
} else if (from == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
|
||||
if (to_ == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
CastData<BFloat16, float>(*X, *Y, shape);
|
||||
} else {
|
||||
// need to cast to float first in a temporary buffer
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
auto tmp_buffer = IAllocator::MakeUniquePtr<float>(allocator, gsl::narrow<size_t>(shape.Size()));
|
||||
Tensor tmp_tensor(DataTypeImpl::GetType<float>(), shape, tmp_buffer.get(), allocator->Info());
|
||||
|
||||
CastData<BFloat16, float>(*X, tmp_tensor, shape);
|
||||
do_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, to_, tmp_tensor, *Y, shape);
|
||||
}
|
||||
} else if (to_ == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
|
||||
if (from == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
CastData<float, BFloat16>(*X, *Y, shape);
|
||||
} else {
|
||||
// need to cast to float first in a temporary buffer
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
auto tmp_buffer = IAllocator::MakeUniquePtr<float>(allocator, gsl::narrow<size_t>(shape.Size()));
|
||||
Tensor tmp_tensor(DataTypeImpl::GetType<float>(), shape, tmp_buffer.get(), allocator->Info());
|
||||
|
||||
do_cast(from, ONNX_NAMESPACE::TensorProto_DataType_FLOAT, *X, tmp_tensor, shape);
|
||||
CastData<float, BFloat16>(tmp_tensor, *Y, shape);
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
do_cast(from, to_, *X, *Y, shape);
|
||||
}
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,18 +1,28 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/data_types_internal.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/providers/op_kernel_type_control.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#IsInf
|
||||
|
||||
namespace op_kernel_type_control {
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
|
||||
kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0,
|
||||
float, double);
|
||||
} // namespace op_kernel_type_control
|
||||
|
||||
class IsInf final : public OpKernel {
|
||||
public:
|
||||
using EnabledTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0);
|
||||
|
||||
explicit IsInf(const OpKernelInfo& info);
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
|
|
@ -25,8 +35,8 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
IsInf,
|
||||
10,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T1", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>()})
|
||||
.TypeConstraint(
|
||||
"T1", BuildKernelDefConstraintsFunctorFromTypeList<IsInf::EnabledTypes>{}())
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
|
||||
IsInf);
|
||||
|
||||
|
|
@ -39,32 +49,34 @@ IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) {
|
|||
|
||||
namespace isinf_internal {
|
||||
template <class T>
|
||||
void ComputeImpl(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) {
|
||||
const auto total_items = X.Shape().Size();
|
||||
auto output_data = Y.template MutableData<bool>();
|
||||
struct ComputeDispatchTarget {
|
||||
void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
|
||||
const auto total_items = X.Shape().Size();
|
||||
auto output_data = Y.template MutableData<bool>();
|
||||
|
||||
if (detect_positive && detect_negative) {
|
||||
EigenMap<bool>(Y) = EigenMap<T>(X).array().isInf();
|
||||
} else if (detect_positive) {
|
||||
auto input_data = X.template Data<T>();
|
||||
auto end_data = input_data + total_items;
|
||||
std::transform(
|
||||
input_data, end_data, output_data, [](T v) {
|
||||
return (v == std::numeric_limits<T>::infinity());
|
||||
});
|
||||
if (detect_positive && detect_negative) {
|
||||
EigenMap<bool>(Y) = EigenMap<T>(X).array().isInf();
|
||||
} else if (detect_positive) {
|
||||
auto input_data = X.template Data<T>();
|
||||
auto end_data = input_data + total_items;
|
||||
std::transform(
|
||||
input_data, end_data, output_data, [](T v) {
|
||||
return (v == std::numeric_limits<T>::infinity());
|
||||
});
|
||||
|
||||
} else if (detect_negative) {
|
||||
auto input_data = X.template Data<T>();
|
||||
auto end_data = input_data + total_items;
|
||||
std::transform(
|
||||
input_data, end_data, output_data, [](T v) {
|
||||
return (v == -std::numeric_limits<T>::infinity());
|
||||
});
|
||||
} else {
|
||||
// all false
|
||||
memset(output_data, false, total_items);
|
||||
} else if (detect_negative) {
|
||||
auto input_data = X.template Data<T>();
|
||||
auto end_data = input_data + total_items;
|
||||
std::transform(
|
||||
input_data, end_data, output_data, [](T v) {
|
||||
return (v == -std::numeric_limits<T>::infinity());
|
||||
});
|
||||
} else {
|
||||
// all false
|
||||
memset(output_data, false, total_items);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace isinf_internal
|
||||
|
||||
Status IsInf::Compute(OpKernelContext* context) const {
|
||||
|
|
@ -75,14 +87,8 @@ Status IsInf::Compute(OpKernelContext* context) const {
|
|||
|
||||
using namespace isinf_internal;
|
||||
|
||||
if (X.IsDataType<float>()) {
|
||||
ComputeImpl<float>(X, Y, detect_positive_ != 0, detect_negative_ != 0);
|
||||
} else if (X.IsDataType<double>()) {
|
||||
ComputeImpl<double>(X, Y, detect_positive_ != 0, detect_negative_ != 0);
|
||||
} else {
|
||||
// should not reach this as no kernel is registered for this condition to be triggered - just an additional safety check
|
||||
ORT_THROW("Data type X must be float or double, but instead got ", X.DataType());
|
||||
}
|
||||
utils::MLTypeCallDispatcherFromTypeList<EnabledTypes> dispatcher{X.GetElementType()};
|
||||
dispatcher.Invoke<ComputeDispatchTarget>(X, Y, detect_positive_ != 0, detect_negative_ != 0);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ Status Pad<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
}
|
||||
}
|
||||
|
||||
T raw_value(0);
|
||||
T raw_value{};
|
||||
const Tensor* value_tensor = ctx->Input<Tensor>(2);
|
||||
if (nullptr != value_tensor) {
|
||||
ORT_ENFORCE(utils::IsPrimitiveDataType<T>(value_tensor->DataType()) &&
|
||||
|
|
|
|||
232
onnxruntime/core/providers/op_kernel_type_control.h
Normal file
232
onnxruntime/core/providers/op_kernel_type_control.h
Normal file
|
|
@ -0,0 +1,232 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <tuple>
|
||||
|
||||
#include "boost/mp11.hpp"
|
||||
|
||||
#include "core/common/type_list.h"
|
||||
#include "core/framework/data_types.h"
|
||||
|
||||
/**
|
||||
* These utilities provide a way to control what types are enabled for an Op kernel implementation.
|
||||
*
|
||||
* At a high level, we have the notion of supported, allowed, and enabled types.
|
||||
* - Supported types are the types that the Op kernel implementation supports by default.
|
||||
* - Allowed types are the types for which support is requested (for example, by external configuration).
|
||||
* - Enabled types are the types that are supported in the actual, compiled implementation. They are obtained from the
|
||||
* intersection of supported and allowed types.
|
||||
*
|
||||
* The types are associated with an Op kernel argument. It is also possible to specify a global list of allowed types.
|
||||
*
|
||||
* Use of these utilities is optional. They are useful for cases where one registered Op kernel handles multiple types.
|
||||
*
|
||||
* See the macros below for usage details.
|
||||
*/
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace op_kernel_type_control {
|
||||
|
||||
enum class OpArgDirection {
|
||||
Input,
|
||||
Output
|
||||
};
|
||||
|
||||
using OpArgIndex = size_t;
|
||||
|
||||
namespace tags {
|
||||
|
||||
// a tag that identifies the target (Op argument) of the specified types
|
||||
template <typename OpTag, OpArgDirection ArgDirection, OpArgIndex ArgIndex>
|
||||
struct OpArg {};
|
||||
|
||||
// a tag that indicates the supported types for a particular Op argument, identified by OpArgTag,
|
||||
// for a kernel in a particular provider, identified by ProviderTag
|
||||
template <typename OpArgTag, typename ProviderTag>
|
||||
struct Supported {};
|
||||
|
||||
// a tag that indicates the allowed types for a particular Op argument, identified by OpArgTag
|
||||
template <typename OpArgTag>
|
||||
struct Allowed {};
|
||||
|
||||
// a tag that indicates the globally allowed types
|
||||
struct GlobalAllowed {};
|
||||
|
||||
} // namespace tags
|
||||
|
||||
// optionally holds a list of types associated with a tag class
|
||||
// if types are defined, the data member 'types' should contain them in a type list
|
||||
// otherwise, if no types are defined (distinct from an empty list of types), there should be no data member 'types'
|
||||
// see the tags in onnxruntime::op_kernel_type_control::tags for intended uses
|
||||
template <typename Tag>
|
||||
struct TypesHolder {};
|
||||
|
||||
/**
|
||||
* Provides a type list of enabled types via the 'types' data member.
|
||||
* Enabled types are the set intersection of supported and allowed types.
|
||||
*
|
||||
* @tparam SupportedTypesHolder A 'TypesHolder' with a list of supported types.
|
||||
* @tparam AllowedTypesHolders A list of 'TypesHolder's each with an optional list of allowed types.
|
||||
*/
|
||||
template <typename SupportedTypesHolder, typename AllowedTypesHolders>
|
||||
struct EnabledTypes {
|
||||
private:
|
||||
static_assert(boost::mp11::mp_is_list<AllowedTypesHolders>::value,
|
||||
"AllowedTypesHolders must be a type list.");
|
||||
|
||||
template <typename T>
|
||||
using GetTypesMember = typename T::types;
|
||||
|
||||
// checks whether T has data member 'types'
|
||||
template <typename T>
|
||||
using HasTypesMember = boost::mp11::mp_valid<GetTypesMember, T>;
|
||||
|
||||
static_assert(HasTypesMember<SupportedTypesHolder>::value,
|
||||
"SupportedTypesHolder must have a 'types' data member.");
|
||||
|
||||
// the allowed type lists to consider
|
||||
// for each element of AllowedTypesHolders, get and include a 'types' data member if present
|
||||
using AllowedTypesMembers =
|
||||
boost::mp11::mp_transform<
|
||||
GetTypesMember,
|
||||
boost::mp11::mp_filter<
|
||||
HasTypesMember,
|
||||
AllowedTypesHolders>>;
|
||||
|
||||
// collect supported and allowed type lists
|
||||
using TypeListsToConsider =
|
||||
boost::mp11::mp_push_front<AllowedTypesMembers, GetTypesMember<SupportedTypesHolder>>;
|
||||
|
||||
static_assert(boost::mp11::mp_all_of<TypeListsToConsider, boost::mp11::mp_is_list>::value,
|
||||
"All 'types' data members must be type lists.");
|
||||
|
||||
// converts type list L into a type set (type list with unique elements)
|
||||
template <typename L>
|
||||
using MakeSet =
|
||||
boost::mp11::mp_apply<
|
||||
boost::mp11::mp_set_push_back,
|
||||
boost::mp11::mp_append<TypeList<TypeList<>>, L>>;
|
||||
|
||||
// type lists converted to type sets
|
||||
using TypeSetsToConsider = boost::mp11::mp_transform<MakeSet, TypeListsToConsider>;
|
||||
|
||||
public:
|
||||
using types = boost::mp11::mp_apply<boost::mp11::mp_set_intersection, TypeSetsToConsider>;
|
||||
};
|
||||
|
||||
} // namespace op_kernel_type_control
|
||||
} // namespace onnxruntime
|
||||
|
||||
// INTERNAL
|
||||
// the class name of a tag type identifying an Op for the purposes of type control
|
||||
#define ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_TAG_CLASS_NAME(OpDomain, OpName) \
|
||||
TypeControl_##OpDomain##_##OpName##_OpTag
|
||||
|
||||
// INTERNAL
|
||||
// the class name of a tag type identifying a provider for the purposes of type control
|
||||
#define ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider) \
|
||||
TypeControl_##OpProvider##_ProviderTag
|
||||
|
||||
// INTERNAL
|
||||
// a tag type identifying an Op argument
|
||||
#define ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_KERNEL_ARG_TAG( \
|
||||
OpDomain, OpName, ArgDirection, ArgIndex) \
|
||||
::onnxruntime::op_kernel_type_control::tags::OpArg< \
|
||||
::onnxruntime::op_kernel_type_control:: \
|
||||
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_TAG_CLASS_NAME(OpDomain, OpName), \
|
||||
::onnxruntime::op_kernel_type_control::OpArgDirection::ArgDirection, \
|
||||
ArgIndex>
|
||||
|
||||
// public macros
|
||||
|
||||
/**
|
||||
* Specifies a supported set of types for a given Op kernel argument.
|
||||
* This should be specified with the Op kernel implementation.
|
||||
*
|
||||
* Note: This should be called from the onnxruntime::op_kernel_type_control namespace.
|
||||
*
|
||||
* @param OpProvider The Op provider.
|
||||
* @param OpDomain The Op domain.
|
||||
* @param OpName The Op name.
|
||||
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
|
||||
* @param ArgIndex Index of the given Op kernel argument.
|
||||
* @param ... The types.
|
||||
*/
|
||||
#define ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES( \
|
||||
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex, ...) \
|
||||
class ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_TAG_CLASS_NAME(OpDomain, OpName); \
|
||||
class ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider); \
|
||||
template <> \
|
||||
struct TypesHolder< \
|
||||
::onnxruntime::op_kernel_type_control::tags::Supported< \
|
||||
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_KERNEL_ARG_TAG(OpDomain, OpName, ArgDirection, ArgIndex), \
|
||||
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider)>> { \
|
||||
using types = ::onnxruntime::TypeList<__VA_ARGS__>; \
|
||||
};
|
||||
|
||||
/**
|
||||
* TypeList type with the enabled types for a given Op kernel argument.
|
||||
*
|
||||
* @param OpProvider The Op provider.
|
||||
* @param OpDomain The Op domain.
|
||||
* @param OpName The Op name.
|
||||
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
|
||||
* @param ArgIndex Index of the given Op kernel argument.
|
||||
*/
|
||||
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( \
|
||||
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \
|
||||
::onnxruntime::op_kernel_type_control::EnabledTypes< \
|
||||
::onnxruntime::op_kernel_type_control::TypesHolder< \
|
||||
::onnxruntime::op_kernel_type_control::tags::Supported< \
|
||||
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_KERNEL_ARG_TAG(OpDomain, OpName, ArgDirection, ArgIndex), \
|
||||
::onnxruntime::op_kernel_type_control:: \
|
||||
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_PROVIDER_TAG_CLASS_NAME(OpProvider)>>, \
|
||||
::onnxruntime::TypeList< \
|
||||
::onnxruntime::op_kernel_type_control::TypesHolder< \
|
||||
::onnxruntime::op_kernel_type_control::tags::Allowed< \
|
||||
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_KERNEL_ARG_TAG(OpDomain, OpName, ArgDirection, ArgIndex)>>, \
|
||||
::onnxruntime::op_kernel_type_control::TypesHolder< \
|
||||
::onnxruntime::op_kernel_type_control::tags::GlobalAllowed>>>::types
|
||||
|
||||
/**
|
||||
* std::tuple type with the enabled types for a given Op kernel argument.
|
||||
*
|
||||
* @param OpProvider The Op provider.
|
||||
* @param OpDomain The Op domain.
|
||||
* @param OpName The Op name.
|
||||
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
|
||||
* @param ArgIndex Index of the given Op kernel argument.
|
||||
*/
|
||||
#define ORT_OP_KERNEL_ARG_ENABLED_TYPE_TUPLE( \
|
||||
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex) \
|
||||
::boost::mp11::mp_rename< \
|
||||
ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( \
|
||||
OpProvider, OpDomain, OpName, ArgDirection, ArgIndex, SupportedTypeList), \
|
||||
std::tuple>
|
||||
|
||||
/**
|
||||
* Usage example:
|
||||
*
|
||||
* In MyProvider provider's implementation of MyOp kernel:
|
||||
*
|
||||
* // specify supported types, i.e., the full set of types that can be enabled
|
||||
* ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
|
||||
* MyProvider, DomainContainingMyOp, MyOp, Input, 0,
|
||||
* int, float, double);
|
||||
*
|
||||
* // get enabled types
|
||||
* using MyOpFirstInputEnabledTypes =
|
||||
* ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(MyProvider, DomainContainingMyOp, MyOp, Input, 0)
|
||||
*
|
||||
* ...
|
||||
*
|
||||
* // in the implementation, we can dispatch to the enabled types
|
||||
* utils::MLTypeCallDispatcherFromTypeList<MyOpFirstInputEnabledTypes> dispatcher{firstInputRuntimeType};
|
||||
* ...
|
||||
*/
|
||||
|
||||
// all allowed type specifications should be contained in the following file
|
||||
#include "core/providers/op_kernel_type_control_overrides.inc"
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace op_kernel_type_control {
|
||||
|
||||
/**
|
||||
* Specifies an allowed set of types for a given Op kernel argument.
|
||||
* This can optionally be specified to further limit the enabled types.
|
||||
*
|
||||
* Note: This should be called from the onnxruntime::op_kernel_type_control namespace.
|
||||
*
|
||||
* @param OpDomain The Op domain.
|
||||
* @param OpName The Op name.
|
||||
* @param ArgDirection Direction of the given Op kernel argument - Input or Output.
|
||||
* @param ArgIndex Index of the given Op kernel argument.
|
||||
* @param ... The types.
|
||||
*/
|
||||
#define ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES( \
|
||||
OpDomain, OpName, ArgDirection, ArgIndex, ...) \
|
||||
class ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_TAG_CLASS_NAME(OpDomain, OpName); \
|
||||
template <> \
|
||||
struct TypesHolder< \
|
||||
::onnxruntime::op_kernel_type_control::tags::Allowed< \
|
||||
ORT_OP_KERNEL_TYPE_CTRL_INTERNAL_OP_KERNEL_ARG_TAG( \
|
||||
OpDomain, OpName, ArgDirection, ArgIndex)>> { \
|
||||
using types = ::onnxruntime::TypeList<__VA_ARGS__>; \
|
||||
};
|
||||
|
||||
/**
|
||||
* Specifies an allowed set of types globally (applicable to any Op kernel argument).
|
||||
* This can optionally be specified to further limit the enabled types.
|
||||
*
|
||||
* Note: This should be called from the onnxruntime::op_kernel_type_control namespace.
|
||||
*
|
||||
* @param ... The types.
|
||||
*/
|
||||
#define ORT_SPECIFY_OP_KERNEL_GLOBAL_ALLOWED_TYPES(...) \
|
||||
template <> \
|
||||
struct TypesHolder< \
|
||||
::onnxruntime::op_kernel_type_control::tags::GlobalAllowed> { \
|
||||
using types = ::onnxruntime::TypeList<__VA_ARGS__>; \
|
||||
};
|
||||
|
||||
// Examples:
|
||||
// Specify allowed types per Op kernel arg:
|
||||
// ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES(kOnnxDomain, Cast, Input, 0, float, int64_t);
|
||||
// ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES(kOnnxDomain, Cast, Output, 0, float, int64_t);
|
||||
// Specify allowed types globally:
|
||||
// ORT_SPECIFY_OP_KERNEL_GLOBAL_ALLOWED_TYPES(float, double, int32_t)
|
||||
|
||||
// specify allowed types here
|
||||
|
||||
// @@insertion_point_begin(allowed_types)@@
|
||||
// @@insertion_point_end(allowed_types)@@
|
||||
|
||||
#undef ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES
|
||||
#undef ORT_SPECIFY_OP_KERNEL_GLOBAL_ALLOWED_TYPES
|
||||
|
||||
} // namespace op_kernel_type_control
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -126,7 +126,7 @@ class RandomValueGenerator {
|
|||
|
||||
template <class T>
|
||||
inline std::vector<T> FillZeros(const std::vector<int64_t>& dims) {
|
||||
std::vector<T> val(detail::SizeFromDims(dims), T(0));
|
||||
std::vector<T> val(detail::SizeFromDims(dims), T{});
|
||||
return val;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ void VerifyTensorProtoFileData(const PathString& tensor_proto_path, gsl::span<co
|
|||
|
||||
std::vector<T> actual_data{};
|
||||
actual_data.resize(expected_data.size());
|
||||
ASSERT_STATUS_OK(utils::UnpackTensor(tensor_proto, actual_data.data(), actual_data.size()));
|
||||
ASSERT_STATUS_OK(utils::UnpackTensor(tensor_proto, Path{}, actual_data.data(), actual_data.size()));
|
||||
|
||||
ASSERT_EQ(gsl::make_span(actual_data), expected_data);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -148,7 +148,7 @@ TEST(ConstantOfShape, TypeTests) {
|
|||
//RunTypedTest(TensorProto::INT16, int16_t(16));
|
||||
|
||||
RunTypedTest(TensorProto::FLOAT, 1.f);
|
||||
RunTypedTest(TensorProto::FLOAT16, MLFloat16(5));
|
||||
RunTypedTest(TensorProto::FLOAT16, MLFloat16(static_cast<uint16_t>(5)));
|
||||
RunTypedTest(TensorProto::DOUBLE, 1.0);
|
||||
RunTypedTest(TensorProto::INT32, int32_t(32));
|
||||
RunTypedTest(TensorProto::INT64, int64_t(64));
|
||||
|
|
|
|||
|
|
@ -291,8 +291,8 @@ TEST(TensorOpTest, CastFromFloat16) {
|
|||
|
||||
TEST(TensorOpTest, CastFromString) {
|
||||
const std::vector<int64_t> shape{2, 2, 2};
|
||||
std::initializer_list<std::string> string_data = {"-inf", "+INF", "0.9767611f", "0.28280696f",
|
||||
"-0.12019656f", "5.0f", "NaN", "nan"};
|
||||
std::initializer_list<std::string> string_data = {"-inf", "+INF", "0.9767611", "0.28280696",
|
||||
"-0.12019656", "5.0", "NaN", "nan"};
|
||||
const std::initializer_list<float> float_output = {-(std::numeric_limits<float>::infinity()), std::numeric_limits<float>::infinity(),
|
||||
0.9767611f, 0.28280696f,
|
||||
-0.12019656f, 5.0f, NAN, NAN};
|
||||
|
|
|
|||
|
|
@ -152,9 +152,10 @@ TEST(TransposeOpTest, TwoDim_mlfloat16) {
|
|||
|
||||
std::vector<int64_t> perm = {1, 0};
|
||||
std::vector<int64_t> expected_shape({3, 2});
|
||||
std::initializer_list<MLFloat16> expected_vals = {MLFloat16(1), MLFloat16(4),
|
||||
MLFloat16(2), MLFloat16(5),
|
||||
MLFloat16(3), MLFloat16(6)};
|
||||
std::initializer_list<MLFloat16> expected_vals =
|
||||
{MLFloat16{static_cast<uint16_t>(1)}, MLFloat16{static_cast<uint16_t>(4)},
|
||||
MLFloat16{static_cast<uint16_t>(2)}, MLFloat16{static_cast<uint16_t>(5)},
|
||||
MLFloat16{static_cast<uint16_t>(3)}, MLFloat16{static_cast<uint16_t>(6)}};
|
||||
|
||||
TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false);
|
||||
}
|
||||
|
|
|
|||
71
onnxruntime/test/providers/op_kernel_type_control_test.cc
Normal file
71
onnxruntime/test/providers/op_kernel_type_control_test.cc
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/op_kernel_type_control.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "boost/mp11.hpp"
|
||||
|
||||
#include "core/common/type_list.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
template <typename A, typename B>
|
||||
struct TypeSetsEqual {
|
||||
private:
|
||||
static_assert(boost::mp11::mp_is_set<A>::value && boost::mp11::mp_is_set<B>::value,
|
||||
"A and B must both be sets.");
|
||||
using ABIntersection = boost::mp11::mp_set_intersection<A, B>;
|
||||
|
||||
public:
|
||||
static constexpr bool value =
|
||||
(boost::mp11::mp_size<A>::value == boost::mp11::mp_size<B>::value) &&
|
||||
(boost::mp11::mp_size<ABIntersection>::value == boost::mp11::mp_size<A>::value);
|
||||
};
|
||||
|
||||
// test types to match op_kernel_type_control::TypesHolder
|
||||
template <typename... T>
|
||||
struct TestTypesHolder {
|
||||
using types = TypeList<T...>;
|
||||
};
|
||||
|
||||
struct TestTypesHolderUnspecified {
|
||||
};
|
||||
|
||||
// supported + allowed for Op
|
||||
static_assert(
|
||||
TypeSetsEqual<
|
||||
op_kernel_type_control::EnabledTypes<
|
||||
TestTypesHolder<int32_t, int64_t, float, double>,
|
||||
TypeList<
|
||||
TestTypesHolder<float, int64_t, char>,
|
||||
TestTypesHolderUnspecified>>::types,
|
||||
TypeList<int64_t, float>>::value,
|
||||
"unexpected enabled types: supported + allowed + unspecified allowed");
|
||||
|
||||
// supported + allowed for Op + allowed globally
|
||||
static_assert(
|
||||
TypeSetsEqual<
|
||||
op_kernel_type_control::EnabledTypes<
|
||||
TestTypesHolder<int32_t, int64_t, float, double>,
|
||||
TypeList<
|
||||
TestTypesHolder<float, int64_t, char>,
|
||||
TestTypesHolder<int64_t>>>::types,
|
||||
TypeList<int64_t>>::value,
|
||||
"unexpected enabled types: supported + allowed + allowed");
|
||||
|
||||
// supported
|
||||
static_assert(
|
||||
TypeSetsEqual<
|
||||
op_kernel_type_control::EnabledTypes<
|
||||
TestTypesHolder<int32_t, int64_t, float, double>,
|
||||
TypeList<
|
||||
TestTypesHolderUnspecified,
|
||||
TestTypesHolderUnspecified>>::types,
|
||||
TypeList<int32_t, int64_t, float, double>>::value,
|
||||
"unexpected enabled types: supported + unspecified allowed + unspecified allowed");
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue