mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Update cast op to support string <-> numeric (#379)
* Update cast kernel to support to/from string * Update namespace * Add support for literal numeric case * Update to support -INF test * Update kernel registration for cast * Update ONNX to 1.4.1 * Update registry api * Resolve some comments * Update cast kernel implementation * Resolve comments * Fixed test data in onnx * Update cast kernel implementation * Resolve PR comments * Update cast_op.cc * Update onnx commits info * Update comments
This commit is contained in:
parent
f72474c24b
commit
ec8ac04f30
9 changed files with 323 additions and 150 deletions
|
|
@ -49,7 +49,7 @@
|
|||
"component":{
|
||||
"type":"git",
|
||||
"git":{
|
||||
"commitHash":"8a1319733a5518bd0001842db27e2df53a306eff",
|
||||
"commitHash":"2896c77cfc628f18b6ca6b28e3a380807fa00f53",
|
||||
"repositoryUrl":"https://github.com/onnx/onnx.git"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
2
cmake/external/onnx
vendored
2
cmake/external/onnx
vendored
|
|
@ -1 +1 @@
|
|||
Subproject commit dbf3581835e3a05716e10587511d7ab3b2cdc386
|
||||
Subproject commit 2896c77cfc628f18b6ca6b28e3a380807fa00f53
|
||||
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
#include "contrib_ops/cpu/quantize_linear.h"
|
||||
#include "core/providers/cpu/math/element_wise_ops.h"
|
||||
#include "core/providers/cpu/tensor/cast_op.h"
|
||||
#include "core/providers/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
|
|
@ -134,18 +134,18 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, GRU);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, LSTM);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, RNN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, uint8_t, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, uint16_t, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, uint32_t, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, uint64_t, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int8_t, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int16_t, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int32_t, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int64_t, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, bool, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, MLFloat16, Cast);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, uint8_t, Cast);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, uint16_t, Cast);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, uint32_t, Cast);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, uint64_t, Cast);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, int8_t, Cast);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, int16_t, Cast);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, int32_t, Cast);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, int64_t, Cast);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, bool, Cast);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, float, Cast);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, double, Cast);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, MLFloat16, Cast);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 4, Concat);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Crop);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Gather);
|
||||
|
|
@ -233,6 +233,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Con
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MeanVarianceNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Cast);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN);
|
||||
|
|
@ -384,18 +385,18 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, GRU)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, LSTM)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, RNN)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, uint8_t, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, uint16_t, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, uint32_t, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, uint64_t, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int8_t, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int16_t, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int32_t, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, int64_t, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, bool, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, float, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, double, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, MLFloat16, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, uint8_t, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, uint16_t, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, uint32_t, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, uint64_t, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, int8_t, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, int16_t, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, int32_t, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, int64_t, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, bool, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, float, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, double, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 9, MLFloat16, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 4, Concat)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Crop)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Gather)>());
|
||||
|
|
@ -485,6 +486,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, Cast)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sign)>());
|
||||
|
|
|
|||
|
|
@ -1,13 +1,202 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cpu/tensor/cast_op.h"
|
||||
|
||||
#include <sstream>
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/util/math.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "Eigen/src/Core/arch/CUDA/Half.h"
|
||||
#include "core/common/common.h"
|
||||
|
||||
#if defined(USE_MLAS) && defined(_M_AMD64)
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#endif
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
namespace onnxruntime {
|
||||
|
||||
template <typename SrcType,
|
||||
typename DstType>
|
||||
inline void CastData(const Tensor* in, Tensor* out, const TensorShape& shape) {
|
||||
auto shape_size = shape.Size();
|
||||
auto in_vector = ConstEigenVectorMap<SrcType>(in->template Data<SrcType>(), shape_size);
|
||||
auto output_vector = EigenVectorMap<DstType>(out->template MutableData<DstType>(), shape_size);
|
||||
output_vector = in_vector.template cast<DstType>();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void CastData<float, MLFloat16>(const Tensor* in, Tensor* out, const TensorShape& shape) {
|
||||
auto out_data = out->template MutableData<MLFloat16>();
|
||||
auto shape_size = shape.Size();
|
||||
auto in_vector = ConstEigenVectorMap<float>(in->template Data<float>(), shape_size);
|
||||
auto output_vector = EigenVectorMap<Eigen::half>(static_cast<Eigen::half*>(static_cast<void*>(out_data)), shape_size);
|
||||
output_vector = in_vector.template cast<Eigen::half>();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void CastData<MLFloat16, float>(const Tensor* in, Tensor* out, const TensorShape& shape) {
|
||||
auto out_data = out->template MutableData<float>();
|
||||
auto in_data = in->template Data<MLFloat16>();
|
||||
auto shape_size = shape.Size();
|
||||
#if defined(USE_MLAS) && defined(_M_AMD64)
|
||||
MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size);
|
||||
#else
|
||||
auto in_vector = ConstEigenVectorMap<Eigen::half>(static_cast<const Eigen::half*>(static_cast<const void*>(in_data)), shape_size);
|
||||
auto output_vector = EigenVectorMap<float>(out_data, shape_size);
|
||||
output_vector = in_vector.template cast<float>();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename SrcType,
|
||||
typename DstType>
|
||||
inline void CastFloat16Data(const Tensor* in, Tensor* out, const TensorShape& shape, const AllocatorPtr& allocator) {
|
||||
ORT_ENFORCE(allocator != nullptr);
|
||||
const int64_t len = shape.Size();
|
||||
ORT_ENFORCE(len > 0);
|
||||
void* buffer = allocator->AllocArray(sizeof(float), len);
|
||||
ORT_ENFORCE(buffer);
|
||||
Tensor tmp_tensor(DataTypeImpl::GetType<float>(), shape, buffer, allocator->Info(), nullptr);
|
||||
if (std::is_same<SrcType, MLFloat16>::value) {
|
||||
CastData<MLFloat16, float>(in, &tmp_tensor, shape); // first cast to float
|
||||
CastData<float, DstType>(&tmp_tensor, out, shape); // then cast to the destination type.
|
||||
} else if (std::is_same<DstType, MLFloat16>::value) {
|
||||
CastData<SrcType, float>(in, &tmp_tensor, shape);
|
||||
CastData<float, MLFloat16>(&tmp_tensor, out, shape);
|
||||
}
|
||||
allocator->Free(buffer);
|
||||
}
|
||||
|
||||
template <typename SrcType>
|
||||
inline void CastToStringData(const Tensor* in, Tensor* out, const TensorShape& shape) {
|
||||
const int64_t len = shape.Size();
|
||||
ORT_ENFORCE(len > 0);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
if (std::is_floating_point<SrcType>::value && std::isnan(in->Data<SrcType>()[i])) {
|
||||
out->MutableData<std::string>()[i] = "NaN";
|
||||
} else if (std::is_floating_point<SrcType>::value && std::isinf(in->Data<SrcType>()[i])) {
|
||||
if (in->Data<SrcType>()[i] < std::numeric_limits<SrcType>::lowest()) {
|
||||
out->MutableData<std::string>()[i] = "-INF";
|
||||
} else {
|
||||
out->MutableData<std::string>()[i] = "INF";
|
||||
}
|
||||
} else {
|
||||
std::ostringstream convert;
|
||||
convert << in->Data<SrcType>()[i];
|
||||
out->MutableData<std::string>()[i] = convert.str();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstType>
|
||||
inline void CastFromStringData(const Tensor* in, Tensor* out, const TensorShape& shape) {
|
||||
if (std::is_same<DstType, std::string>::value) return;
|
||||
const int64_t len = shape.Size();
|
||||
ORT_ENFORCE(len > 0);
|
||||
if (std::is_same<DstType, float>::value) {
|
||||
float* mutable_data = out->MutableData<float>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
mutable_data[i] = std::stof(in->Data<std::string>()[i]);
|
||||
}
|
||||
} else if (std::is_same<DstType, double>::value) {
|
||||
double* mutable_data = out->MutableData<double>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
mutable_data[i] = std::stod(in->Data<std::string>()[i]);
|
||||
}
|
||||
} else if (std::is_same<DstType, int8_t>::value) {
|
||||
int8_t* mutable_data = out->MutableData<int8_t>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
int temp_i = std::stoi(in->Data<std::string>()[i]);
|
||||
mutable_data[i] = static_cast<int8_t>(temp_i);
|
||||
}
|
||||
} else if (std::is_same<DstType, uint8_t>::value) {
|
||||
uint8_t* mutable_data = out->MutableData<uint8_t>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
unsigned long temp_ui = std::stoul(in->Data<std::string>()[i]);
|
||||
mutable_data[i] = static_cast<uint8_t>(temp_ui);
|
||||
}
|
||||
} else if (std::is_same<DstType, int16_t>::value) {
|
||||
int16_t* mutable_data = out->MutableData<int16_t>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
int temp_i = std::stoi(in->Data<std::string>()[i]);
|
||||
mutable_data[i] = static_cast<int16_t>(temp_i);
|
||||
}
|
||||
} else if (std::is_same<DstType, uint16_t>::value) {
|
||||
uint16_t* mutable_data = out->MutableData<uint16_t>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
unsigned long temp_ui = std::stoul(in->Data<std::string>()[i]);
|
||||
mutable_data[i] = static_cast<uint16_t>(temp_ui);
|
||||
}
|
||||
} else if (std::is_same<DstType, int32_t>::value) {
|
||||
int32_t* mutable_data = out->MutableData<int32_t>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
mutable_data[i] = std::stol(in->Data<std::string>()[i]);
|
||||
}
|
||||
} else if (std::is_same<DstType, uint32_t>::value) {
|
||||
uint32_t* mutable_data = out->MutableData<uint32_t>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
mutable_data[i] = std::stoul(in->Data<std::string>()[i]);
|
||||
}
|
||||
} else if (std::is_same<DstType, int64_t>::value) {
|
||||
int64_t* mutable_data = out->MutableData<int64_t>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
mutable_data[i] = std::stoll(in->Data<std::string>()[i]);
|
||||
}
|
||||
} else if (std::is_same<DstType, uint64_t>::value) {
|
||||
uint64_t* mutable_data = out->MutableData<uint64_t>();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
mutable_data[i] = std::stoull(in->Data<std::string>()[i]);
|
||||
}
|
||||
} else {
|
||||
ORT_THROW("Unsupported type in cast op: from String to ", typeid(DstType).name());
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
||||
template <typename T>
|
||||
class Cast final : public OpKernel {
|
||||
public:
|
||||
Cast(const OpKernelInfo& info) : OpKernel(info) {
|
||||
int64_t to;
|
||||
Status status = info.GetAttr("to", &to);
|
||||
ORT_ENFORCE(status.IsOK(), "Attribute to is not set.");
|
||||
to_ = gsl::narrow_cast<ONNX_NAMESPACE::TensorProto_DataType>(to);
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
template <typename SrcType,
|
||||
typename DstType>
|
||||
void CastData(const Tensor* in, Tensor* out, const TensorShape& shape) const {
|
||||
::onnxruntime::CastData<SrcType, DstType>(in, out, shape);
|
||||
}
|
||||
|
||||
template <typename SrcType,
|
||||
typename DstType>
|
||||
Status CastFloat16Data(const Tensor* in, Tensor* out, const TensorShape& shape, OpKernelContext* context) const {
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
::onnxruntime::CastFloat16Data<SrcType, DstType>(in, out, shape, allocator);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename SrcType>
|
||||
Status CastToStringData(const Tensor* in, Tensor* out, const TensorShape& shape) const {
|
||||
::onnxruntime::CastToStringData<SrcType>(in, out, shape);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename DstType>
|
||||
Status CastFromStringData(const Tensor* in, Tensor* out, const TensorShape& shape) const {
|
||||
::onnxruntime::CastFromStringData<DstType>(in, out, shape);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ONNX_NAMESPACE::TensorProto_DataType to_;
|
||||
};
|
||||
|
||||
|
||||
const std::vector<MLDataType> castOpTypeConstraints{
|
||||
DataTypeImpl::GetTensorType<bool>(),
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
|
|
@ -20,12 +209,14 @@ const std::vector<MLDataType> castOpTypeConstraints{
|
|||
DataTypeImpl::GetTensorType<int16_t>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<MLFloat16>()};
|
||||
DataTypeImpl::GetTensorType<MLFloat16>(),
|
||||
DataTypeImpl::GetTensorType<std::string>()};
|
||||
|
||||
#define ADD_FROM_CAST_OP(in_type) \
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
|
||||
Cast, \
|
||||
6, \
|
||||
9, \
|
||||
in_type, \
|
||||
KernelDefBuilder().TypeConstraint("T1", DataTypeImpl::GetTensorType<in_type>()).TypeConstraint("T2", castOpTypeConstraints), \
|
||||
Cast<in_type>); \
|
||||
|
|
@ -33,7 +224,6 @@ const std::vector<MLDataType> castOpTypeConstraints{
|
|||
template <> \
|
||||
Status Cast<in_type>::Compute(OpKernelContext* context) const { \
|
||||
const Tensor* X = context->Input<Tensor>(0); \
|
||||
if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); \
|
||||
const TensorShape& shape = X->Shape(); \
|
||||
Tensor* Y = context->Output(0, TensorShape(shape)); \
|
||||
\
|
||||
|
|
@ -80,11 +270,12 @@ const std::vector<MLDataType> castOpTypeConstraints{
|
|||
} \
|
||||
break; \
|
||||
case TensorProto_DataType_STRING: \
|
||||
ORT_THROW("Casting to and from strings is not supported yet."); /*break;*/ \
|
||||
CastToStringData<in_type>(X, Y, shape); \
|
||||
break; \
|
||||
case TensorProto_DataType_UNDEFINED: \
|
||||
ORT_THROW("Cast op must have 'to' argument of type DataType"); /*break;*/ \
|
||||
ORT_THROW("Cast op must have 'to' argument of type DataType"); /*break;*/ \
|
||||
default: \
|
||||
ORT_THROW("Unexpected 'to' argument value: ", to_); \
|
||||
ORT_THROW("Unexpected 'to' argument value: ", to_); \
|
||||
} \
|
||||
return Status::OK(); \
|
||||
}
|
||||
|
|
@ -101,9 +292,10 @@ ADD_FROM_CAST_OP(bool);
|
|||
ADD_FROM_CAST_OP(float);
|
||||
ADD_FROM_CAST_OP(double);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
|
||||
Cast,
|
||||
6,
|
||||
9,
|
||||
MLFloat16,
|
||||
KernelDefBuilder().TypeConstraint("T1", DataTypeImpl::GetTensorType<MLFloat16>()).TypeConstraint("T2", castOpTypeConstraints),
|
||||
Cast<MLFloat16>);
|
||||
|
|
@ -111,7 +303,6 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
|||
template <>
|
||||
Status Cast<MLFloat16>::Compute(OpKernelContext* context) const {
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
const TensorShape& shape = X->Shape();
|
||||
Tensor* Y = context->Output(0, TensorShape(shape));
|
||||
Status st;
|
||||
|
|
@ -144,14 +335,14 @@ Status Cast<MLFloat16>::Compute(OpKernelContext* context) const {
|
|||
CastData<MLFloat16, float>(X, Y, shape);
|
||||
break;
|
||||
case TensorProto_DataType_FLOAT16: {
|
||||
auto X_type = X->DataType();
|
||||
const void* source = X->DataRaw(X_type);
|
||||
void* target = Y->MutableDataRaw(X_type);
|
||||
// if source and target pointers are not equal, we need to copy the data.
|
||||
if (target != source) {
|
||||
memcpy(target, source, shape.Size() * X_type->Size());
|
||||
}
|
||||
st = Status::OK();
|
||||
auto X_type = X->DataType();
|
||||
const void* source = X->DataRaw(X_type);
|
||||
void* target = Y->MutableDataRaw(X_type);
|
||||
// if source and target pointers are not equal, we need to copy the data.
|
||||
if (target != source) {
|
||||
memcpy(target, source, shape.Size() * X_type->Size());
|
||||
}
|
||||
st = Status::OK();
|
||||
break;
|
||||
}
|
||||
case TensorProto_DataType_DOUBLE:
|
||||
|
|
@ -170,4 +361,57 @@ Status Cast<MLFloat16>::Compute(OpKernelContext* context) const {
|
|||
return st;
|
||||
}
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_KERNEL(
|
||||
Cast,
|
||||
9,
|
||||
string,
|
||||
KernelDefBuilder().TypeConstraint("T1", DataTypeImpl::GetTensorType<std::string>()).TypeConstraint("T2", castOpTypeConstraints),
|
||||
Cast<std::string>);
|
||||
|
||||
template <>
|
||||
Status Cast<std::string>::Compute(OpKernelContext* context) const {
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL,
|
||||
"Input is missing. The operator Cast expects one and only one input");
|
||||
const TensorShape& shape = X->Shape();
|
||||
Tensor* Y = context->Output(0, TensorShape(shape));
|
||||
Status st;
|
||||
switch (to_) {
|
||||
case TensorProto_DataType_INT16:
|
||||
st = CastFromStringData<int16_t>(X, Y, shape);
|
||||
break;
|
||||
case TensorProto_DataType_INT32:
|
||||
st = CastFromStringData<int32_t>(X, Y, shape);
|
||||
break;
|
||||
case TensorProto_DataType_INT64:
|
||||
st = CastFromStringData<int64_t>(X, Y, shape);
|
||||
break;
|
||||
case TensorProto_DataType_UINT8:
|
||||
st = CastFromStringData<uint8_t>(X, Y, shape);
|
||||
break;
|
||||
case TensorProto_DataType_UINT16:
|
||||
st = CastFromStringData<uint16_t>(X, Y, shape);
|
||||
break;
|
||||
case TensorProto_DataType_UINT32:
|
||||
st = CastFromStringData<uint32_t>(X, Y, shape);
|
||||
break;
|
||||
case TensorProto_DataType_UINT64:
|
||||
st = CastFromStringData<uint64_t>(X, Y, shape);
|
||||
break;
|
||||
case TensorProto_DataType_FLOAT:
|
||||
st = CastFromStringData<float>(X, Y, shape);
|
||||
break;
|
||||
case TensorProto_DataType_DOUBLE:
|
||||
st = CastFromStringData<double>(X, Y, shape);
|
||||
break;
|
||||
case TensorProto_DataType_INT8:
|
||||
st = CastFromStringData<int8_t>(X, Y, shape);
|
||||
break;
|
||||
case TensorProto_DataType_UNDEFINED:
|
||||
ORT_THROW("Cast op must have 'to' argument of type DataType");
|
||||
default:
|
||||
ORT_THROW("Unexpected 'to' argument value: ", to_);
|
||||
}
|
||||
return st;
|
||||
}
|
||||
} //namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,100 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/util/math.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "Eigen/src/Core/arch/CUDA/Half.h"
|
||||
|
||||
#if defined(USE_MLAS) && defined(_M_AMD64)
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
template <typename SrcType,
|
||||
typename DstType>
|
||||
inline void CastData(const Tensor* in, Tensor* out, const TensorShape& shape) {
|
||||
auto shape_size = shape.Size();
|
||||
auto in_vector = ConstEigenVectorMap<SrcType>(in->template Data<SrcType>(), shape_size);
|
||||
auto output_vector = EigenVectorMap<DstType>(out->template MutableData<DstType>(), shape_size);
|
||||
output_vector = in_vector.template cast<DstType>();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void CastData<float, MLFloat16>(const Tensor* in, Tensor* out, const TensorShape& shape) {
|
||||
auto out_data = out->template MutableData<MLFloat16>();
|
||||
auto shape_size = shape.Size();
|
||||
auto in_vector = ConstEigenVectorMap<float>(in->template Data<float>(), shape_size);
|
||||
auto output_vector = EigenVectorMap<Eigen::half>(static_cast<Eigen::half*>(static_cast<void*>(out_data)), shape_size);
|
||||
output_vector = in_vector.template cast<Eigen::half>();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void CastData<MLFloat16, float>(const Tensor* in, Tensor* out, const TensorShape& shape) {
|
||||
auto out_data = out->template MutableData<float>();
|
||||
auto in_data = in->template Data<MLFloat16>();
|
||||
auto shape_size = shape.Size();
|
||||
#if defined(USE_MLAS) && defined(_M_AMD64)
|
||||
MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size);
|
||||
#else
|
||||
auto in_vector = ConstEigenVectorMap<Eigen::half>(static_cast<const Eigen::half*>(static_cast<const void*>(in_data)), shape_size);
|
||||
auto output_vector = EigenVectorMap<float>(out_data, shape_size);
|
||||
output_vector = in_vector.template cast<float>();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename SrcType,
|
||||
typename DstType>
|
||||
inline void CastFloat16Data(const Tensor* in, Tensor* out, const TensorShape& shape, const AllocatorPtr& allocator) {
|
||||
ORT_ENFORCE(allocator != nullptr);
|
||||
const int64_t len = shape.Size();
|
||||
ORT_ENFORCE(len > 0);
|
||||
void* buffer = allocator->AllocArray(sizeof(float), len);
|
||||
ORT_ENFORCE(buffer);
|
||||
Tensor tmp_tensor(DataTypeImpl::GetType<float>(), shape, buffer, allocator->Info(), nullptr);
|
||||
if (std::is_same<SrcType, MLFloat16>::value) {
|
||||
CastData<MLFloat16, float>(in, &tmp_tensor, shape); // first cast to float
|
||||
CastData<float, DstType>(&tmp_tensor, out, shape); // then cast to the destination type.
|
||||
} else if (std::is_same<DstType, MLFloat16>::value) {
|
||||
CastData<SrcType, float>(in, &tmp_tensor, shape);
|
||||
CastData<float, MLFloat16>(&tmp_tensor, out, shape);
|
||||
}
|
||||
allocator->Free(buffer);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Cast final : public OpKernel {
|
||||
public:
|
||||
Cast(const OpKernelInfo& info) : OpKernel(info) {
|
||||
int64_t to;
|
||||
Status status = info.GetAttr("to", &to);
|
||||
ORT_ENFORCE(status.IsOK(), "Attribute to is not set.");
|
||||
to_ = gsl::narrow_cast<ONNX_NAMESPACE::TensorProto_DataType>(to);
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
template <typename SrcType,
|
||||
typename DstType>
|
||||
void CastData(const Tensor* in, Tensor* out, const TensorShape& shape) const {
|
||||
::onnxruntime::CastData<SrcType, DstType>(in, out, shape);
|
||||
}
|
||||
|
||||
template <typename SrcType,
|
||||
typename DstType>
|
||||
Status CastFloat16Data(const Tensor* in, Tensor* out, const TensorShape& shape, OpKernelContext* context) const {
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
::onnxruntime::CastFloat16Data<SrcType, DstType>(in, out, shape, allocator);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ONNX_NAMESPACE::TensorProto_DataType to_;
|
||||
};
|
||||
|
||||
} //namespace onnxruntime
|
||||
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "core/providers/cpu/tensor/cast_op.h"
|
||||
#include "core/providers/cpu/tensor/crop.h"
|
||||
#include "core/util/math.h"
|
||||
|
||||
|
|
@ -80,7 +79,7 @@ void TestCastOp(const std::initializer_list<SrcType>& input,
|
|||
int64_t toType,
|
||||
ExpectResult expect_result = ExpectResult::kExpectSuccess,
|
||||
const std::string& expected_failure_string = "") {
|
||||
OpTester test("Cast");
|
||||
OpTester test("Cast", 9);
|
||||
test.AddAttribute("to", toType);
|
||||
test.AddInput<SrcType>("input", dimensions, input);
|
||||
test.AddOutput<DstType>("output", dimensions, output);
|
||||
|
|
@ -277,6 +276,32 @@ TEST(TensorOpTest, CastFromFloat16) {
|
|||
TestCastOp(input, int64_t_data, shape, TensorProto::INT64);
|
||||
}
|
||||
|
||||
// Verifies string -> numeric casts: float (including infinities and NaN
// spellings), int16 and int64 (including their extreme values).
TEST(TensorOpTest, CastFromString) {
  const std::vector<int64_t> shape{2, 2, 2};

  // string -> float
  {
    std::initializer_list<std::string> float_text = {"-inf", "+INF", "2.0f", "3.0f", "4.0f", "5.0f", "NaN", "nan"};
    const std::initializer_list<float> float_expected = {-(std::numeric_limits<float>::infinity()), std::numeric_limits<float>::infinity(), 2.0f, 3.0f, 4.0f, 5.0f, NAN, NAN};
    TestCastOp(float_text, float_expected, shape, TensorProto::FLOAT);
  }

  // string -> int16
  {
    std::initializer_list<std::string> int16_text = {"0", "1", "2", "3", "4", "5", "-32768", "32767"};
    const std::initializer_list<int16_t> int16_expected = {0, 1, 2, 3, 4, 5, SHRT_MIN, SHRT_MAX};
    TestCastOp(int16_text, int16_expected, shape, TensorProto::INT16);
  }

  // string -> int64
  {
    std::initializer_list<std::string> int64_text = {"0", "1", "2", "3", "4", "5", "-9223372036854775808", "9223372036854775807"};
    const std::initializer_list<int64_t> int64_expected = {0, 1, 2, 3, 4, 5, LLONG_MIN, LLONG_MAX};
    TestCastOp(int64_text, int64_expected, shape, TensorProto::INT64);
  }
}
|
||||
|
||||
// Verifies numeric -> string casts: float (including NaN and both infinities)
// and int16.
TEST(TensorOpTest, CastToString) {
  const std::vector<int64_t> shape{2, 2, 2};

  // float -> string
  {
    const std::initializer_list<float> float_values = {NAN, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, -std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity()};
    std::initializer_list<std::string> float_text = {"NaN", "1", "2", "3", "4", "5", "-INF", "INF"};
    TestCastOp(float_values, float_text, shape, TensorProto::STRING);
  }

  // int16 -> string
  {
    std::initializer_list<std::string> int16_text = {"0", "1", "2", "3", "4", "5", "6", "7"};
    const std::initializer_list<int16_t> int16_values = {0, 1, 2, 3, 4, 5, 6, 7};
    TestCastOp(int16_values, int16_text, shape, TensorProto::STRING);
  }
}
|
||||
|
||||
TEST(TensorOpTest, CropBorderOnly) {
|
||||
const int N = 2, C = 1, H = 3, W = 4;
|
||||
std::vector<float> X = {1.0f, 2.0f, 3.0f, 4.0f,
|
||||
|
|
|
|||
|
|
@ -50,9 +50,11 @@ void Check<float>(const OpTester::Data& expected_data, const Tensor& output_tens
|
|||
#endif
|
||||
|
||||
for (int i = 0; i < size; ++i) {
|
||||
if (std::isinf(expected[i])) // Test infinity for equality
|
||||
if (std::isinf(expected[i])){ // Test infinity for equality
|
||||
EXPECT_EQ(expected[i], output[i]);
|
||||
else {
|
||||
} else if (std::isnan(expected[i])) {
|
||||
EXPECT_TRUE(std::isnan(output[i])) << "Expected output " << i << " to be NaN";
|
||||
} else {
|
||||
if (!has_abs_err && !has_rel_err) {
|
||||
// the default for existing tests
|
||||
EXPECT_NEAR(expected[i], output[i], threshold) << "provider_type: " << provider_type;
|
||||
|
|
|
|||
|
|
@ -37,8 +37,9 @@ else
|
|||
#Install ONNX
|
||||
#5af210ca8a1c73aa6bae8754c9346ec54d0a756e is v1.2.3
|
||||
#bae6333e149a59a3faa9c4d9c44974373dcf5256 is v1.3.0
|
||||
#dbf3581835e3a05716e10587511d7ab3b2cdc386 is v1.3.0 latest
|
||||
for onnx_version in "5af210ca8a1c73aa6bae8754c9346ec54d0a756e" "bae6333e149a59a3faa9c4d9c44974373dcf5256" "dbf3581835e3a05716e10587511d7ab3b2cdc386"; do
|
||||
#9e55ace55aad1ada27516038dfbdc66a8a0763db is v1.4.1
|
||||
#2896c77cfc628f18b6ca6b28e3a380807fa00f53 is v1.4.1 latest
|
||||
for onnx_version in "5af210ca8a1c73aa6bae8754c9346ec54d0a756e" "bae6333e149a59a3faa9c4d9c44974373dcf5256" "9e55ace55aad1ada27516038dfbdc66a8a0763db" "2896c77cfc628f18b6ca6b28e3a380807fa00f53"; do
|
||||
if [ -z ${lastest_onnx_version+x} ]; then
|
||||
echo "first pass";
|
||||
else
|
||||
|
|
|
|||
Loading…
Reference in a new issue