Enable type reduction in EyeLike, Mod, random.cc CPU kernels. (#6960)

* Update EyeLike CPU kernel.

* Update Mod CPU kernel.

* Update Multinomial CPU kernel.

* Slight improvement to Pad CPU kernel binary size.

* Update RandomNormal[Like], RandomUniform[Like] CPU kernels.
This commit is contained in:
Edward Chen 2021-03-09 21:32:56 -08:00 committed by GitHub
parent 89916fdb05
commit d5ed3e7fba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 240 additions and 151 deletions

View file

@ -344,7 +344,7 @@ class MLTypeCallDispatcher {
* @tparam Fn The function object template.
* @tparam Args The argument types.
*/
template <template <typename> class Fn, typename... Args>
template <template <typename...> class Fn, typename... Args>
void Invoke(Args&&... args) const {
InvokeWithLeadingTemplateArgs<Fn, TypeList<>>(std::forward<Args>(args)...);
}
@ -384,7 +384,7 @@ class MLTypeCallDispatcher {
* @tparam Fn The function object template.
* @tparam Args The argument types.
*/
template <class Ret, template <typename> class Fn, typename... Args>
template <class Ret, template <typename...> class Fn, typename... Args>
Ret InvokeRet(Args&&... args) const {
return InvokeRetWithUnsupportedPolicy<
Ret, Fn, mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy<Ret>>(
@ -401,7 +401,7 @@ class MLTypeCallDispatcher {
* for an example.
* @tparam Args The argument types.
*/
template <class Ret, template <typename> class Fn, class UnsupportedPolicy, typename... Args>
template <class Ret, template <typename...> class Fn, class UnsupportedPolicy, typename... Args>
Ret InvokeRetWithUnsupportedPolicy(Args&&... args) const {
mltype_dispatcher_internal::CallableDispatchableRetHelper<Ret, UnsupportedPolicy> helper(dt_type_);

View file

@ -25,47 +25,122 @@ limitations under the License.
#include <chrono>
#include <random>
#include "core/common/safeint.h"
#include "core/util/math_cpuonly.h"
#include "core/common/eigen_common_wrapper.h"
#include "gsl/gsl"
#include "core/common/eigen_common_wrapper.h"
#include "core/common/safeint.h"
#include "core/providers/op_kernel_type_control.h"
#include "core/providers/op_kernel_type_control_utils.h"
#include "core/util/math_cpuonly.h"
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {
namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, RandomNormal, Output, 0,
float, double);
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, RandomUniform, Output, 0,
float, double);
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, RandomNormalLike, Output, 0,
float, double);
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, RandomUniformLike, Output, 0,
float, double);
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Multinomial, Output, 0,
int32_t, int64_t);
}
using RandomNormalOutputTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, RandomNormal, Output, 0);
using EnabledRandomNormalOutputTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, RandomNormal, Output, 0);
using RandomUniformOutputTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, RandomUniform, Output, 0);
using EnabledRandomUniformOutputTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, RandomUniform, Output, 0);
using RandomNormalLikeOutputTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, RandomNormalLike, Output, 0);
using EnabledRandomNormalLikeOutputTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, RandomNormalLike, Output, 0);
using RandomUniformLikeOutputTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, RandomUniformLike, Output, 0);
using EnabledRandomUniformLikeOutputTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, RandomUniformLike, Output, 0);
using MultinomialOutputTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Multinomial, Output, 0);
using EnabledMultinomialOutputTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Multinomial, Output, 0);
using EnabledRandomUniformComputeOutputTypes =
utils::TypeSetUnion<
EnabledRandomUniformOutputTypes,
EnabledRandomUniformLikeOutputTypes>;
using EnabledRandomNormalComputeOutputTypes =
utils::TypeSetUnion<
EnabledRandomNormalOutputTypes,
EnabledRandomNormalLikeOutputTypes>;
ONNX_CPU_OPERATOR_KERNEL(
RandomNormal,
1,
KernelDefBuilder().TypeConstraint("T", std::vector<MLDataType>{
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>()}),
KernelDefBuilder()
.TypeConstraint("T",
BuildKernelDefConstraintsFromTypeList<RandomNormalOutputTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledRandomNormalOutputTypes>()),
RandomNormal);
ONNX_CPU_OPERATOR_KERNEL(
RandomUniform,
1,
KernelDefBuilder().TypeConstraint("T", std::vector<MLDataType>{
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>()}),
KernelDefBuilder()
.TypeConstraint("T",
BuildKernelDefConstraintsFromTypeList<RandomUniformOutputTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledRandomUniformOutputTypes>()),
RandomUniform);
ONNX_CPU_OPERATOR_KERNEL(
RandomNormalLike,
1,
KernelDefBuilder().TypeConstraint("T1", DataTypeImpl::AllTensorTypes()).TypeConstraint("T2", std::vector<MLDataType>{DataTypeImpl::GetTensorType<float>(), DataTypeImpl::GetTensorType<double>()}),
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::AllTensorTypes())
.TypeConstraint("T2",
BuildKernelDefConstraintsFromTypeList<RandomNormalLikeOutputTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledRandomNormalLikeOutputTypes>()),
RandomNormalLike);
ONNX_CPU_OPERATOR_KERNEL(
RandomUniformLike,
1,
KernelDefBuilder().TypeConstraint("T1", DataTypeImpl::AllTensorTypes()).TypeConstraint("T2", std::vector<MLDataType>{DataTypeImpl::GetTensorType<float>(), DataTypeImpl::GetTensorType<double>()}),
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::AllTensorTypes())
.TypeConstraint("T2",
BuildKernelDefConstraintsFromTypeList<RandomUniformLikeOutputTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledRandomUniformLikeOutputTypes>()),
RandomUniformLike);
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#multinomial
ONNX_CPU_OPERATOR_KERNEL(
Multinomial,
7,
KernelDefBuilder().TypeConstraint("T1", DataTypeImpl::GetTensorType<float>()).TypeConstraint("T2", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T2",
BuildKernelDefConstraintsFromTypeList<MultinomialOutputTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledMultinomialOutputTypes>()),
Multinomial);
template <typename T, typename TDistribution>
@ -156,6 +231,10 @@ static Status MultinomialCompute(OpKernelContext* ctx,
const int64_t num_samples,
std::default_random_engine& generator,
Tensor& Y) {
if (!utils::HasType<EnabledMultinomialOutputTypes, OutputType>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output type not supported in this build.");
}
// implementation copied from Tensorflow with some changes such as using the std::uniform_real_distribution
// instead of the Philox RNG.
Eigen::array<int64_t, 2> X_dims = {{batch_size, num_classes}};
@ -275,22 +354,30 @@ static TensorProto::DataType InferDataType(const Tensor& tensor) {
static Status RandomNormalCompute(float mean, float scale,
std::default_random_engine& generator,
TensorProto::DataType dtype, Tensor& Y) {
bool handled = false;
switch (dtype) {
case TensorProto::FLOAT: {
GenerateData<float, std::normal_distribution<float>>(
generator, std::normal_distribution<float>{mean, scale}, Y);
if (utils::HasType<EnabledRandomNormalComputeOutputTypes, float>()) {
GenerateData<float, std::normal_distribution<float>>(
generator, std::normal_distribution<float>{mean, scale}, Y);
handled = true;
}
break;
}
case TensorProto::FLOAT16: {
ORT_NOT_IMPLEMENTED("FLOAT16 is not supported");
}
case TensorProto::DOUBLE: {
GenerateData<double, std::normal_distribution<double>>(
generator, std::normal_distribution<double>{mean, scale}, Y);
if (utils::HasType<EnabledRandomNormalComputeOutputTypes, double>()) {
GenerateData<double, std::normal_distribution<double>>(
generator, std::normal_distribution<double>{mean, scale}, Y);
handled = true;
}
break;
}
default:
ORT_THROW("Invalid data type of ", dtype);
break;
}
if (!handled) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output type not supported in this build: ", dtype);
}
return Status::OK();
@ -300,22 +387,30 @@ static Status RandomUniformCompute(float low, float high,
std::default_random_engine& generator,
TensorProto::DataType dtype,
Tensor& Y) {
bool handled = false;
switch (dtype) {
case TensorProto::FLOAT: {
GenerateData<float, std::uniform_real_distribution<float>>(
generator, std::uniform_real_distribution<float>{low, high}, Y);
if (utils::HasType<EnabledRandomUniformComputeOutputTypes, float>()) {
GenerateData<float, std::uniform_real_distribution<float>>(
generator, std::uniform_real_distribution<float>{low, high}, Y);
handled = true;
}
break;
}
case TensorProto::FLOAT16: {
ORT_NOT_IMPLEMENTED("FLOAT16 is not supported");
}
case TensorProto::DOUBLE: {
GenerateData<double, std::uniform_real_distribution<double>>(
generator, std::uniform_real_distribution<double>{low, high}, Y);
if (utils::HasType<EnabledRandomUniformComputeOutputTypes, double>()) {
GenerateData<double, std::uniform_real_distribution<double>>(
generator, std::uniform_real_distribution<double>{low, high}, Y);
handled = true;
}
break;
}
default:
ORT_THROW("Invalid data type of ", dtype);
break;
}
if (!handled) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output type not supported in this build: ", dtype);
}
return Status::OK();

View file

@ -29,6 +29,11 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Min,
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0,
int64_t);
// Mod
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, Mod, Input, 0,
float, double, int64_t, uint64_t, int32_t, uint32_t,
int16_t, uint16_t, int8_t, uint8_t, MLFloat16);
// Pow
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Pow, 7, Input, 0, float, double);
@ -54,6 +59,10 @@ using Min12Types = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(kCpuExecutionProvider, kO
using EnabledMin8Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Min, 8, Input, 0);
using EnabledMin12Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0);
using ModTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, Mod, Input, 0);
using EnabledModTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, Mod, Input, 0);
using Pow7Types = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Pow, 7, Input, 0);
using Pow12BaseTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Pow, 12, Input, 0);
using Pow12ExpTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Pow, 12, Input, 1);
@ -1491,33 +1500,21 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Mod,
10,
12,
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<uint64_t>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>(),
DataTypeImpl::GetTensorType<int16_t>(),
DataTypeImpl::GetTensorType<uint16_t>(),
DataTypeImpl::GetTensorType<int8_t>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<MLFloat16>()}),
KernelDefBuilder()
.TypeConstraint(
"T",
BuildKernelDefConstraintsFromTypeList<ModTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledModTypes>()),
Mod);
ONNX_CPU_OPERATOR_KERNEL(
Mod,
13,
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<uint64_t>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>(),
DataTypeImpl::GetTensorType<int16_t>(),
DataTypeImpl::GetTensorType<uint16_t>(),
DataTypeImpl::GetTensorType<int8_t>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<MLFloat16>()}),
KernelDefBuilder()
.TypeConstraint(
"T",
BuildKernelDefConstraintsFromTypeList<ModTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledModTypes>()),
Mod);
namespace mod_internal {
@ -1605,7 +1602,7 @@ void BroadCastMod(OpKernelContext* context) {
UntypedBroadcastTwo(*context, funcs);
}
void BroadCastMFloat16FMod(OpKernelContext* context) {
void BroadCastMLFloat16FMod(OpKernelContext* context) {
ProcessBroadcastSpanFuncs funcs{
[](BroadcastHelper& per_iter_bh) {
const auto X = per_iter_bh.ScalarInput0<MLFloat16>();
@ -1643,9 +1640,12 @@ void BroadCastMFloat16FMod(OpKernelContext* context) {
UntypedBroadcastTwo(*context, funcs);
}
// Generic implementation of Mod kernel
template <class T, typename Enable = void>
struct CallModImpl;
// Generic implementation of Mod kernel, non-floating point types
template <class T>
struct CallModImpl {
struct CallModImpl<T, typename std::enable_if<!std::is_floating_point<T>::value>::type> {
void operator()(bool fmod, OpKernelContext* ctx) const {
if (fmod) {
BroadCastFMod<T>(ctx);
@ -1655,32 +1655,32 @@ struct CallModImpl {
}
};
// Generic implementation of Mod kernel, floating point types
template <class T>
struct CallModImpl<T, typename std::enable_if<std::is_floating_point<T>::value, void>::type> {
void operator()(bool fmod, OpKernelContext* ctx) const {
ORT_ENFORCE(fmod, "fmod attribute must be true for floating point types");
BroadCastFMod<T>(ctx);
}
};
// MLFloat16 implementation of Mod kernel
template <>
struct CallModImpl<MLFloat16> {
void operator()(bool fmod, OpKernelContext* ctx) const {
ORT_ENFORCE(fmod, "fmod attribute must be true for floating point types");
BroadCastMLFloat16FMod(ctx);
}
};
} // namespace mod_internal
Status Mod::Compute(OpKernelContext* context) const {
const auto& X = *context->Input<Tensor>(0);
auto dt_type = X.GetElementType();
const auto dt_type = X.GetElementType();
switch (dt_type) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
ORT_ENFORCE(fmod_, "fmod attribute must be true for float, float16 and double types");
mod_internal::BroadCastFMod<float>(context);
break;
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
ORT_ENFORCE(fmod_, "fmod attribute must be true for float, float16 and double types");
mod_internal::BroadCastFMod<double>(context);
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
ORT_ENFORCE(fmod_, "fmod attribute must be true for float, float16 and double types");
mod_internal::BroadCastMFloat16FMod(context);
break;
default:
utils::MLTypeCallDispatcher<uint8_t, int8_t, uint16_t, int16_t,
uint32_t, int32_t, uint64_t, int64_t>
t_disp(dt_type);
t_disp.Invoke<mod_internal::CallModImpl>(fmod_, context);
break;
}
utils::MLTypeCallDispatcherFromTypeList<EnabledModTypes> t_disp(dt_type);
t_disp.Invoke<mod_internal::CallModImpl>(fmod_, context);
return Status::OK();
}

View file

@ -2,74 +2,77 @@
// Licensed under the MIT License.
#include "core/providers/cpu/tensor/eye_like.h"
#include "core/common/common.h"
#include "core/providers/op_kernel_type_control.h"
#include "core/util/math_cpuonly.h"
using namespace ::onnxruntime::common;
namespace onnxruntime {
namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, EyeLike, Output, 0,
float, double, uint64_t, int64_t, int32_t);
}
using EyeLikeDataTypes = ORT_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, EyeLike, Output, 0);
using EnabledEyeLikeDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, EyeLike, Output, 0);
ONNX_CPU_OPERATOR_KERNEL(
EyeLike,
9,
KernelDefBuilder().TypeConstraint("T1",
std::vector<MLDataType>{
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<uint64_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<int32_t>()
})
.TypeConstraint("T2",
std::vector<MLDataType>{
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<uint64_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<int32_t>()
}),
KernelDefBuilder()
.TypeConstraint(
"T1",
BuildKernelDefConstraintsFromTypeList<EyeLikeDataTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledEyeLikeDataTypes>())
.TypeConstraint(
"T2",
BuildKernelDefConstraintsFromTypeList<EyeLikeDataTypes>(),
BuildKernelDefConstraintsFromTypeList<EnabledEyeLikeDataTypes>()),
EyeLike);
Status EyeLike::Compute(OpKernelContext* context) const {
const auto* T1 = context->Input<Tensor>(0);
ORT_ENFORCE(T1 != nullptr);
auto output_tensor_dtype = has_dtype_ ? static_cast<ONNX_NAMESPACE::TensorProto::DataType>(dtype_) : T1->GetElementType();
switch (output_tensor_dtype) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
return ComputeImpl<float>(context, T1);
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
return ComputeImpl<double>(context, T1);
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
return ComputeImpl<int32_t>(context, T1);
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
return ComputeImpl<uint64_t>(context, T1);
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
return ComputeImpl<int64_t>(context, T1);
default:
ORT_THROW("Unsupported 'dtype' value: ", output_tensor_dtype);
}
}
namespace {
template <typename T>
Status EyeLike::ComputeImpl(OpKernelContext* context, const Tensor* T1) const {
const std::vector<int64_t>& input_dims = T1->Shape().GetDims();
if (input_dims.size() != 2) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "EyeLike : Input tensor dimension is not 2");
struct ComputeDispatchTarget {
void operator()(const int64_t k, Tensor& output) {
const auto& output_shape = output.Shape();
auto output_mat = EigenMatrixMapRowMajor<T>(
output.template MutableData<T>(),
output_shape[0],
output_shape[1]);
output_mat.setZero();
if ((k >= 0 && k >= output_shape[1]) || (k < 0 && std::abs(k) >= output_shape[0])) {
return;
}
output_mat.diagonal(k).array() = static_cast<T>(1);
}
};
} // namespace
Status EyeLike::Compute(OpKernelContext* context) const {
const auto& T1 = context->RequiredInput<Tensor>(0);
const auto& input_shape = T1.Shape();
if (input_shape.NumDimensions() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "EyeLike : Input tensor dimension is not 2");
}
// set output tensor shape same as input tensor and set all values to zero
auto* T2 = context->Output(0, input_dims);
auto output_mat = EigenMatrixMapRowMajor<T>(
T2->template MutableData<T>(),
input_dims[0],
input_dims[1]);
output_mat.setZero();
if ((k_ >= 0 && k_ >= input_dims[1]) || (k_ < 0 && std::abs(k_) >= input_dims[0])) {
return Status::OK();
}
output_mat.diagonal(k_).array() = static_cast<T>(1);
// set output tensor shape same as input tensor
auto& T2 = context->RequiredOutput(0, input_shape);
const auto output_tensor_dtype =
has_dtype_ ? static_cast<ONNX_NAMESPACE::TensorProto::DataType>(dtype_) : T1.GetElementType();
utils::MLTypeCallDispatcherFromTypeList<EnabledEyeLikeDataTypes> dispatcher{output_tensor_dtype};
dispatcher.Invoke<ComputeDispatchTarget>(k_, T2);
return Status::OK();
}
} // namespace onnxruntime

View file

@ -21,9 +21,6 @@ class EyeLike final : public OpKernel {
Status Compute(OpKernelContext* context) const override;
private:
template <typename T>
Status ComputeImpl(OpKernelContext* context, const Tensor* T1) const;
bool has_dtype_;
int64_t dtype_;
int64_t k_;

View file

@ -3,7 +3,6 @@
#include "core/providers/cpu/tensor/pad.h"
#include "core/common/optional.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/op_kernel_type_control.h"
#include "core/providers/op_kernel_type_control_utils.h"
@ -245,6 +244,10 @@ static Status PadImpl(OpKernelContext* ctx,
const std::vector<int64_t>& slices,
const Mode& mode,
T value) {
if (!utils::HasType<AllEnabledPadTypes, T>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input data type not supported in this build.");
}
const auto& input_tensor = *ctx->Input<Tensor>(0);
const auto& orig_input_shape = input_tensor.Shape();
std::vector<int64_t> output_dims(orig_input_shape.GetDims());
@ -514,31 +517,22 @@ Status Pad::Compute(OpKernelContext* ctx) const {
slices_to_use = &slices_;
}
optional<Status> pad_status{};
Status pad_status{};
switch (element_size) {
case sizeof(uint32_t):
if (utils::HasTypeWithSameSize<AllEnabledPadTypes, uint32_t>()) {
pad_status = PadImpl<uint32_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u32);
}
pad_status = PadImpl<uint32_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u32);
break;
case sizeof(uint64_t):
if (utils::HasTypeWithSameSize<AllEnabledPadTypes, uint64_t>()) {
pad_status = PadImpl<uint64_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u64);
}
pad_status = PadImpl<uint64_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u64);
break;
case sizeof(uint8_t):
if (utils::HasTypeWithSameSize<AllEnabledPadTypes, uint8_t>()) {
pad_status = PadImpl<uint8_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u8);
}
pad_status = PadImpl<uint8_t>(ctx, *pads_to_use, *slices_to_use, mode_, value.u8);
break;
default:
pad_status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input data type of ", data_type);
break;
}
if (!pad_status) {
pad_status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input data type of ", data_type);
}
return *pad_status;
return pad_status;
}
}; // namespace onnxruntime