mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Enable type reduction in EyeLike, Mod, random.cc CPU kernels. (#6960)
* Update EyeLike CPU kernel. * Update Mod CPU kernel. * Update Multinomial CPU kernel. * Slight improvement to Pad CPU kernel binary size. * Update RandomNormal[Like], RandomUniform[Like] CPU kernels.
This commit is contained in:
parent
89916fdb05
commit
d5ed3e7fba
6 changed files with 240 additions and 151 deletions
|
|
@ -344,7 +344,7 @@ class MLTypeCallDispatcher {
|
|||
* @tparam Fn The function object template.
|
||||
* @tparam Args The argument types.
|
||||
*/
|
||||
template <template <typename> class Fn, typename... Args>
|
||||
template <template <typename...> class Fn, typename... Args>
|
||||
void Invoke(Args&&... args) const {
|
||||
InvokeWithLeadingTemplateArgs<Fn, TypeList<>>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
|
@ -384,7 +384,7 @@ class MLTypeCallDispatcher {
|
|||
* @tparam Fn The function object template.
|
||||
* @tparam Args The argument types.
|
||||
*/
|
||||
template <class Ret, template <typename> class Fn, typename... Args>
|
||||
template <class Ret, template <typename...> class Fn, typename... Args>
|
||||
Ret InvokeRet(Args&&... args) const {
|
||||
return InvokeRetWithUnsupportedPolicy<
|
||||
Ret, Fn, mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy<Ret>>(
|
||||
|
|
@ -401,7 +401,7 @@ class MLTypeCallDispatcher {
|
|||
* for an example.
|
||||
* @tparam Args The argument types.
|
||||
*/
|
||||
template <class Ret, template <typename> class Fn, class UnsupportedPolicy, typename... Args>
|
||||
template <class Ret, template <typename...> class Fn, class UnsupportedPolicy, typename... Args>
|
||||
Ret InvokeRetWithUnsupportedPolicy(Args&&... args) const {
|
||||
mltype_dispatcher_internal::CallableDispatchableRetHelper<Ret, UnsupportedPolicy> helper(dt_type_);
|
||||
|
||||
|
|
|
|||
|
|
@ -25,47 +25,122 @@ limitations under the License.
|
|||
#include <chrono>
|
||||
#include <random>
|
||||
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "core/common/eigen_common_wrapper.h"
|
||||
#include "gsl/gsl"
|
||||
|
||||
#include "core/common/eigen_common_wrapper.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/providers/op_kernel_type_control.h"
|
||||
#include "core/providers/op_kernel_type_control_utils.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace op_kernel_type_control {
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, RandomNormal, Output, 0,
|
||||
float, double);
|
||||
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, RandomUniform, Output, 0,
|
||||
float, double);
|
||||
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, RandomNormalLike, Output, 0,
|
||||
float, double);
|
||||
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, RandomUniformLike, Output, 0,
|
||||
float, double);
|
||||
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, Multinomial, Output, 0,
|
||||
int32_t, int64_t);
|
||||
}
|
||||
|
||||
using RandomNormalOutputTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, RandomNormal, Output, 0);
|
||||
using EnabledRandomNormalOutputTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, RandomNormal, Output, 0);
|
||||
|
||||
using RandomUniformOutputTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, RandomUniform, Output, 0);
|
||||
using EnabledRandomUniformOutputTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, RandomUniform, Output, 0);
|
||||
|
||||
using RandomNormalLikeOutputTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, RandomNormalLike, Output, 0);
|
||||
using EnabledRandomNormalLikeOutputTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, RandomNormalLike, Output, 0);
|
||||
|
||||
using RandomUniformLikeOutputTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, RandomUniformLike, Output, 0);
|
||||
using EnabledRandomUniformLikeOutputTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, RandomUniformLike, Output, 0);
|
||||
|
||||
using MultinomialOutputTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, Multinomial, Output, 0);
|
||||
using EnabledMultinomialOutputTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, Multinomial, Output, 0);
|
||||
|
||||
using EnabledRandomUniformComputeOutputTypes =
|
||||
utils::TypeSetUnion<
|
||||
EnabledRandomUniformOutputTypes,
|
||||
EnabledRandomUniformLikeOutputTypes>;
|
||||
|
||||
using EnabledRandomNormalComputeOutputTypes =
|
||||
utils::TypeSetUnion<
|
||||
EnabledRandomNormalOutputTypes,
|
||||
EnabledRandomNormalLikeOutputTypes>;
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
RandomNormal,
|
||||
1,
|
||||
KernelDefBuilder().TypeConstraint("T", std::vector<MLDataType>{
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>()}),
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T",
|
||||
BuildKernelDefConstraintsFromTypeList<RandomNormalOutputTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledRandomNormalOutputTypes>()),
|
||||
RandomNormal);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
RandomUniform,
|
||||
1,
|
||||
KernelDefBuilder().TypeConstraint("T", std::vector<MLDataType>{
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>()}),
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T",
|
||||
BuildKernelDefConstraintsFromTypeList<RandomUniformOutputTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledRandomUniformOutputTypes>()),
|
||||
RandomUniform);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
RandomNormalLike,
|
||||
1,
|
||||
KernelDefBuilder().TypeConstraint("T1", DataTypeImpl::AllTensorTypes()).TypeConstraint("T2", std::vector<MLDataType>{DataTypeImpl::GetTensorType<float>(), DataTypeImpl::GetTensorType<double>()}),
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T1", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("T2",
|
||||
BuildKernelDefConstraintsFromTypeList<RandomNormalLikeOutputTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledRandomNormalLikeOutputTypes>()),
|
||||
RandomNormalLike);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
RandomUniformLike,
|
||||
1,
|
||||
KernelDefBuilder().TypeConstraint("T1", DataTypeImpl::AllTensorTypes()).TypeConstraint("T2", std::vector<MLDataType>{DataTypeImpl::GetTensorType<float>(), DataTypeImpl::GetTensorType<double>()}),
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T1", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("T2",
|
||||
BuildKernelDefConstraintsFromTypeList<RandomUniformLikeOutputTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledRandomUniformLikeOutputTypes>()),
|
||||
RandomUniformLike);
|
||||
|
||||
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#multinomial
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Multinomial,
|
||||
7,
|
||||
KernelDefBuilder().TypeConstraint("T1", DataTypeImpl::GetTensorType<float>()).TypeConstraint("T2", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
|
||||
.TypeConstraint("T2",
|
||||
BuildKernelDefConstraintsFromTypeList<MultinomialOutputTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledMultinomialOutputTypes>()),
|
||||
Multinomial);
|
||||
|
||||
template <typename T, typename TDistribution>
|
||||
|
|
@ -156,6 +231,10 @@ static Status MultinomialCompute(OpKernelContext* ctx,
|
|||
const int64_t num_samples,
|
||||
std::default_random_engine& generator,
|
||||
Tensor& Y) {
|
||||
if (!utils::HasType<EnabledMultinomialOutputTypes, OutputType>()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output type not supported in this build.");
|
||||
}
|
||||
|
||||
// implementation copied from Tensorflow with some changes such as using the std::uniform_real_distribution
|
||||
// instead of the Philox RNG.
|
||||
Eigen::array<int64_t, 2> X_dims = {{batch_size, num_classes}};
|
||||
|
|
@ -275,22 +354,30 @@ static TensorProto::DataType InferDataType(const Tensor& tensor) {
|
|||
static Status RandomNormalCompute(float mean, float scale,
|
||||
std::default_random_engine& generator,
|
||||
TensorProto::DataType dtype, Tensor& Y) {
|
||||
bool handled = false;
|
||||
switch (dtype) {
|
||||
case TensorProto::FLOAT: {
|
||||
GenerateData<float, std::normal_distribution<float>>(
|
||||
generator, std::normal_distribution<float>{mean, scale}, Y);
|
||||
if (utils::HasType<EnabledRandomNormalComputeOutputTypes, float>()) {
|
||||
GenerateData<float, std::normal_distribution<float>>(
|
||||
generator, std::normal_distribution<float>{mean, scale}, Y);
|
||||
handled = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case TensorProto::FLOAT16: {
|
||||
ORT_NOT_IMPLEMENTED("FLOAT16 is not supported");
|
||||
}
|
||||
case TensorProto::DOUBLE: {
|
||||
GenerateData<double, std::normal_distribution<double>>(
|
||||
generator, std::normal_distribution<double>{mean, scale}, Y);
|
||||
if (utils::HasType<EnabledRandomNormalComputeOutputTypes, double>()) {
|
||||
GenerateData<double, std::normal_distribution<double>>(
|
||||
generator, std::normal_distribution<double>{mean, scale}, Y);
|
||||
handled = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
ORT_THROW("Invalid data type of ", dtype);
|
||||
break;
|
||||
}
|
||||
|
||||
if (!handled) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output type not supported in this build: ", dtype);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
@ -300,22 +387,30 @@ static Status RandomUniformCompute(float low, float high,
|
|||
std::default_random_engine& generator,
|
||||
TensorProto::DataType dtype,
|
||||
Tensor& Y) {
|
||||
bool handled = false;
|
||||
switch (dtype) {
|
||||
case TensorProto::FLOAT: {
|
||||
GenerateData<float, std::uniform_real_distribution<float>>(
|
||||
generator, std::uniform_real_distribution<float>{low, high}, Y);
|
||||
if (utils::HasType<EnabledRandomUniformComputeOutputTypes, float>()) {
|
||||
GenerateData<float, std::uniform_real_distribution<float>>(
|
||||
generator, std::uniform_real_distribution<float>{low, high}, Y);
|
||||
handled = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case TensorProto::FLOAT16: {
|
||||
ORT_NOT_IMPLEMENTED("FLOAT16 is not supported");
|
||||
}
|
||||
case TensorProto::DOUBLE: {
|
||||
GenerateData<double, std::uniform_real_distribution<double>>(
|
||||
generator, std::uniform_real_distribution<double>{low, high}, Y);
|
||||
if (utils::HasType<EnabledRandomUniformComputeOutputTypes, double>()) {
|
||||
GenerateData<double, std::uniform_real_distribution<double>>(
|
||||
generator, std::uniform_real_distribution<double>{low, high}, Y);
|
||||
handled = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
ORT_THROW("Invalid data type of ", dtype);
|
||||
break;
|
||||
}
|
||||
|
||||
if (!handled) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output type not supported in this build: ", dtype);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -29,6 +29,11 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Min,
|
|||
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0,
|
||||
int64_t);
|
||||
|
||||
// Mod
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, Mod, Input, 0,
|
||||
float, double, int64_t, uint64_t, int32_t, uint32_t,
|
||||
int16_t, uint16_t, int8_t, uint8_t, MLFloat16);
|
||||
|
||||
// Pow
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 7, Input, 0, float, double);
|
||||
|
||||
|
|
@ -54,6 +59,10 @@ using Min12Types = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(kCpuExecutionProvider, kO
|
|||
using EnabledMin8Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Min, 8, Input, 0);
|
||||
using EnabledMin12Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0);
|
||||
|
||||
using ModTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, Mod, Input, 0);
|
||||
using EnabledModTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, Mod, Input, 0);
|
||||
|
||||
using Pow7Types = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Pow, 7, Input, 0);
|
||||
using Pow12BaseTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Pow, 12, Input, 0);
|
||||
using Pow12ExpTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Pow, 12, Input, 1);
|
||||
|
|
@ -1491,33 +1500,21 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
|||
Mod,
|
||||
10,
|
||||
12,
|
||||
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<uint64_t>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<uint32_t>(),
|
||||
DataTypeImpl::GetTensorType<int16_t>(),
|
||||
DataTypeImpl::GetTensorType<uint16_t>(),
|
||||
DataTypeImpl::GetTensorType<int8_t>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<MLFloat16>()}),
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
BuildKernelDefConstraintsFromTypeList<ModTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledModTypes>()),
|
||||
Mod);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Mod,
|
||||
13,
|
||||
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<uint64_t>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<uint32_t>(),
|
||||
DataTypeImpl::GetTensorType<int16_t>(),
|
||||
DataTypeImpl::GetTensorType<uint16_t>(),
|
||||
DataTypeImpl::GetTensorType<int8_t>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<MLFloat16>()}),
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
BuildKernelDefConstraintsFromTypeList<ModTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledModTypes>()),
|
||||
Mod);
|
||||
|
||||
namespace mod_internal {
|
||||
|
|
@ -1605,7 +1602,7 @@ void BroadCastMod(OpKernelContext* context) {
|
|||
UntypedBroadcastTwo(*context, funcs);
|
||||
}
|
||||
|
||||
void BroadCastMFloat16FMod(OpKernelContext* context) {
|
||||
void BroadCastMLFloat16FMod(OpKernelContext* context) {
|
||||
ProcessBroadcastSpanFuncs funcs{
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
const auto X = per_iter_bh.ScalarInput0<MLFloat16>();
|
||||
|
|
@ -1643,9 +1640,12 @@ void BroadCastMFloat16FMod(OpKernelContext* context) {
|
|||
UntypedBroadcastTwo(*context, funcs);
|
||||
}
|
||||
|
||||
// Generic implementation of Mod kernel
|
||||
template <class T, typename Enable = void>
|
||||
struct CallModImpl;
|
||||
|
||||
// Generic implementation of Mod kernel, non-floating point types
|
||||
template <class T>
|
||||
struct CallModImpl {
|
||||
struct CallModImpl<T, typename std::enable_if<!std::is_floating_point<T>::value>::type> {
|
||||
void operator()(bool fmod, OpKernelContext* ctx) const {
|
||||
if (fmod) {
|
||||
BroadCastFMod<T>(ctx);
|
||||
|
|
@ -1655,32 +1655,32 @@ struct CallModImpl {
|
|||
}
|
||||
};
|
||||
|
||||
// Generic implementation of Mod kernel, floating point types
|
||||
template <class T>
|
||||
struct CallModImpl<T, typename std::enable_if<std::is_floating_point<T>::value, void>::type> {
|
||||
void operator()(bool fmod, OpKernelContext* ctx) const {
|
||||
ORT_ENFORCE(fmod, "fmod attribute must be true for floating point types");
|
||||
BroadCastFMod<T>(ctx);
|
||||
}
|
||||
};
|
||||
|
||||
// MLFloat16 implementation of Mod kernel
|
||||
template <>
|
||||
struct CallModImpl<MLFloat16> {
|
||||
void operator()(bool fmod, OpKernelContext* ctx) const {
|
||||
ORT_ENFORCE(fmod, "fmod attribute must be true for floating point types");
|
||||
BroadCastMLFloat16FMod(ctx);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mod_internal
|
||||
|
||||
Status Mod::Compute(OpKernelContext* context) const {
|
||||
const auto& X = *context->Input<Tensor>(0);
|
||||
auto dt_type = X.GetElementType();
|
||||
const auto dt_type = X.GetElementType();
|
||||
|
||||
switch (dt_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
|
||||
ORT_ENFORCE(fmod_, "fmod attribute must be true for float, float16 and double types");
|
||||
mod_internal::BroadCastFMod<float>(context);
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
|
||||
ORT_ENFORCE(fmod_, "fmod attribute must be true for float, float16 and double types");
|
||||
mod_internal::BroadCastFMod<double>(context);
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
|
||||
ORT_ENFORCE(fmod_, "fmod attribute must be true for float, float16 and double types");
|
||||
mod_internal::BroadCastMFloat16FMod(context);
|
||||
break;
|
||||
default:
|
||||
utils::MLTypeCallDispatcher<uint8_t, int8_t, uint16_t, int16_t,
|
||||
uint32_t, int32_t, uint64_t, int64_t>
|
||||
t_disp(dt_type);
|
||||
t_disp.Invoke<mod_internal::CallModImpl>(fmod_, context);
|
||||
break;
|
||||
}
|
||||
utils::MLTypeCallDispatcherFromTypeList<EnabledModTypes> t_disp(dt_type);
|
||||
t_disp.Invoke<mod_internal::CallModImpl>(fmod_, context);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,74 +2,77 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cpu/tensor/eye_like.h"
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/op_kernel_type_control.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
|
||||
using namespace ::onnxruntime::common;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace op_kernel_type_control {
|
||||
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, EyeLike, Output, 0,
|
||||
float, double, uint64_t, int64_t, int32_t);
|
||||
}
|
||||
|
||||
using EyeLikeDataTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, EyeLike, Output, 0);
|
||||
using EnabledEyeLikeDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
|
||||
kCpuExecutionProvider, kOnnxDomain, EyeLike, Output, 0);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
EyeLike,
|
||||
9,
|
||||
KernelDefBuilder().TypeConstraint("T1",
|
||||
std::vector<MLDataType>{
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<uint64_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>()
|
||||
})
|
||||
.TypeConstraint("T2",
|
||||
std::vector<MLDataType>{
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<uint64_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>()
|
||||
}),
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint(
|
||||
"T1",
|
||||
BuildKernelDefConstraintsFromTypeList<EyeLikeDataTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledEyeLikeDataTypes>())
|
||||
.TypeConstraint(
|
||||
"T2",
|
||||
BuildKernelDefConstraintsFromTypeList<EyeLikeDataTypes>(),
|
||||
BuildKernelDefConstraintsFromTypeList<EnabledEyeLikeDataTypes>()),
|
||||
EyeLike);
|
||||
|
||||
Status EyeLike::Compute(OpKernelContext* context) const {
|
||||
const auto* T1 = context->Input<Tensor>(0);
|
||||
ORT_ENFORCE(T1 != nullptr);
|
||||
|
||||
auto output_tensor_dtype = has_dtype_ ? static_cast<ONNX_NAMESPACE::TensorProto::DataType>(dtype_) : T1->GetElementType();
|
||||
switch (output_tensor_dtype) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
|
||||
return ComputeImpl<float>(context, T1);
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
|
||||
return ComputeImpl<double>(context, T1);
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
|
||||
return ComputeImpl<int32_t>(context, T1);
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
|
||||
return ComputeImpl<uint64_t>(context, T1);
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
|
||||
return ComputeImpl<int64_t>(context, T1);
|
||||
default:
|
||||
ORT_THROW("Unsupported 'dtype' value: ", output_tensor_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
Status EyeLike::ComputeImpl(OpKernelContext* context, const Tensor* T1) const {
|
||||
const std::vector<int64_t>& input_dims = T1->Shape().GetDims();
|
||||
if (input_dims.size() != 2) {
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "EyeLike : Input tensor dimension is not 2");
|
||||
struct ComputeDispatchTarget {
|
||||
void operator()(const int64_t k, Tensor& output) {
|
||||
const auto& output_shape = output.Shape();
|
||||
auto output_mat = EigenMatrixMapRowMajor<T>(
|
||||
output.template MutableData<T>(),
|
||||
output_shape[0],
|
||||
output_shape[1]);
|
||||
|
||||
output_mat.setZero();
|
||||
|
||||
if ((k >= 0 && k >= output_shape[1]) || (k < 0 && std::abs(k) >= output_shape[0])) {
|
||||
return;
|
||||
}
|
||||
|
||||
output_mat.diagonal(k).array() = static_cast<T>(1);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
Status EyeLike::Compute(OpKernelContext* context) const {
|
||||
const auto& T1 = context->RequiredInput<Tensor>(0);
|
||||
|
||||
const auto& input_shape = T1.Shape();
|
||||
if (input_shape.NumDimensions() != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "EyeLike : Input tensor dimension is not 2");
|
||||
}
|
||||
|
||||
// set output tensor shape same as input tensor and set all values to zero
|
||||
auto* T2 = context->Output(0, input_dims);
|
||||
auto output_mat = EigenMatrixMapRowMajor<T>(
|
||||
T2->template MutableData<T>(),
|
||||
input_dims[0],
|
||||
input_dims[1]);
|
||||
output_mat.setZero();
|
||||
|
||||
if ((k_ >= 0 && k_ >= input_dims[1]) || (k_ < 0 && std::abs(k_) >= input_dims[0])) {
|
||||
return Status::OK();
|
||||
}
|
||||
output_mat.diagonal(k_).array() = static_cast<T>(1);
|
||||
// set output tensor shape same as input tensor
|
||||
auto& T2 = context->RequiredOutput(0, input_shape);
|
||||
|
||||
const auto output_tensor_dtype =
|
||||
has_dtype_ ? static_cast<ONNX_NAMESPACE::TensorProto::DataType>(dtype_) : T1.GetElementType();
|
||||
|
||||
utils::MLTypeCallDispatcherFromTypeList<EnabledEyeLikeDataTypes> dispatcher{output_tensor_dtype};
|
||||
dispatcher.Invoke<ComputeDispatchTarget>(k_, T2);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -21,9 +21,6 @@ class EyeLike final : public OpKernel {
|
|||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
Status ComputeImpl(OpKernelContext* context, const Tensor* T1) const;
|
||||
|
||||
bool has_dtype_;
|
||||
int64_t dtype_;
|
||||
int64_t k_;
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
#include "core/providers/cpu/tensor/pad.h"
|
||||
|
||||
#include "core/common/optional.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "core/providers/op_kernel_type_control.h"
|
||||
#include "core/providers/op_kernel_type_control_utils.h"
|
||||
|
|
@ -245,6 +244,10 @@ static Status PadImpl(OpKernelContext* ctx,
|
|||
const std::vector<int64_t>& slices,
|
||||
const Mode& mode,
|
||||
T value) {
|
||||
if (!utils::HasType<AllEnabledPadTypes, T>()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input data type not supported in this build.");
|
||||
}
|
||||
|
||||
const auto& input_tensor = *ctx->Input<Tensor>(0);
|
||||
const auto& orig_input_shape = input_tensor.Shape();
|
||||
std::vector<int64_t> output_dims(orig_input_shape.GetDims());
|
||||
|
|
@ -514,31 +517,22 @@ Status Pad::Compute(OpKernelContext* ctx) const {
|
|||
slices_to_use = &slices_;
|
||||
}
|
||||
|
||||
optional<Status> pad_status{};
|
||||
Status pad_status{};
|
||||
switch (element_size) {
|
||||
case sizeof(uint32_t):
|
||||
if (utils::HasTypeWithSameSize<AllEnabledPadTypes, uint32_t>()) {
|
||||
pad_status = PadImpl<uint32_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u32);
|
||||
}
|
||||
pad_status = PadImpl<uint32_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u32);
|
||||
break;
|
||||
case sizeof(uint64_t):
|
||||
if (utils::HasTypeWithSameSize<AllEnabledPadTypes, uint64_t>()) {
|
||||
pad_status = PadImpl<uint64_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u64);
|
||||
}
|
||||
pad_status = PadImpl<uint64_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u64);
|
||||
break;
|
||||
case sizeof(uint8_t):
|
||||
if (utils::HasTypeWithSameSize<AllEnabledPadTypes, uint8_t>()) {
|
||||
pad_status = PadImpl<uint8_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u8);
|
||||
}
|
||||
pad_status = PadImpl<uint8_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u8);
|
||||
break;
|
||||
default:
|
||||
pad_status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input data type of ", data_type);
|
||||
break;
|
||||
}
|
||||
|
||||
if (!pad_status) {
|
||||
pad_status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input data type of ", data_type);
|
||||
}
|
||||
|
||||
return *pad_status;
|
||||
return pad_status;
|
||||
}
|
||||
}; // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue