mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Reduce DataTypeImpl binary size (#9783)
* Reduce the number of virtual methods in DataTypeImpl to reduce binary size. Refactor some helpers to reduce the amount of templatized code.
This commit is contained in:
parent
f390347c11
commit
dc1724b0e2
2 changed files with 101 additions and 125 deletions
|
|
@ -33,7 +33,7 @@ namespace onnxruntime {
|
|||
|
||||
#if !defined(DISABLE_ML_OPS)
|
||||
|
||||
//maps (only used by ML ops)
|
||||
// maps (only used by ML ops)
|
||||
using MapStringToString = std::map<std::string, std::string>;
|
||||
using MapStringToInt64 = std::map<std::string, int64_t>;
|
||||
using MapStringToFloat = std::map<std::string, float>;
|
||||
|
|
@ -43,7 +43,7 @@ using MapInt64ToInt64 = std::map<int64_t, int64_t>;
|
|||
using MapInt64ToFloat = std::map<int64_t, float>;
|
||||
using MapInt64ToDouble = std::map<int64_t, double>;
|
||||
|
||||
//vectors/sequences
|
||||
// vectors/sequences
|
||||
using VectorMapStringToFloat = std::vector<MapStringToFloat>;
|
||||
using VectorMapInt64ToFloat = std::vector<MapInt64ToFloat>;
|
||||
|
||||
|
|
@ -78,6 +78,22 @@ using CreateFunc = void* (*)();
|
|||
*
|
||||
*/
|
||||
class DataTypeImpl {
|
||||
public:
|
||||
enum class GeneralType {
|
||||
kInvalid = 0,
|
||||
kNonTensor = 1,
|
||||
kTensor = 2,
|
||||
kTensorSequence = 3,
|
||||
kSparseTensor = 4,
|
||||
kOptional = 5
|
||||
};
|
||||
|
||||
GeneralType type_;
|
||||
size_t size_;
|
||||
|
||||
protected:
|
||||
DataTypeImpl(GeneralType type, size_t size) : type_{type}, size_{size} {}
|
||||
|
||||
public:
|
||||
virtual ~DataTypeImpl() = default;
|
||||
|
||||
|
|
@ -90,7 +106,7 @@ class DataTypeImpl {
|
|||
*/
|
||||
virtual bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const = 0;
|
||||
|
||||
virtual size_t Size() const = 0;
|
||||
size_t Size() const { return size_; }
|
||||
|
||||
virtual DeleteFunc GetDeleteFunc() const = 0;
|
||||
|
||||
|
|
@ -102,20 +118,20 @@ class DataTypeImpl {
|
|||
*/
|
||||
virtual const ONNX_NAMESPACE::TypeProto* GetTypeProto() const = 0;
|
||||
|
||||
virtual bool IsTensorType() const {
|
||||
return false;
|
||||
bool IsTensorType() const {
|
||||
return type_ == GeneralType::kTensor;
|
||||
}
|
||||
|
||||
virtual bool IsTensorSequenceType() const {
|
||||
return false;
|
||||
bool IsTensorSequenceType() const {
|
||||
return type_ == GeneralType::kTensorSequence;
|
||||
}
|
||||
|
||||
virtual bool IsSparseTensorType() const {
|
||||
return false;
|
||||
bool IsSparseTensorType() const {
|
||||
return type_ == GeneralType::kSparseTensor;
|
||||
}
|
||||
|
||||
virtual bool IsOptionalType() const {
|
||||
return false;
|
||||
bool IsOptionalType() const {
|
||||
return type_ == GeneralType::kOptional;
|
||||
}
|
||||
|
||||
// Returns this if this is of tensor-type and null otherwise
|
||||
|
|
@ -383,17 +399,17 @@ struct SetMapTypes {
|
|||
void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto&,
|
||||
ONNX_NAMESPACE::TypeProto&);
|
||||
|
||||
template <typename T>
|
||||
struct SetSequenceType {
|
||||
static void Set(ONNX_NAMESPACE::TypeProto& proto) {
|
||||
// helper to create TypeProto with minimal binary size impact
|
||||
struct SequenceTypeHelper {
|
||||
template <typename T>
|
||||
static const onnx::TypeProto* GetElemType() {
|
||||
MLDataType dt = GetMLDataType<T, IsTensorContainedType<T>::value>::Get();
|
||||
const auto* elem_proto = dt->GetTypeProto();
|
||||
#ifdef ORT_NO_RTTI
|
||||
return elem_proto; // check for nullptr is in Set
|
||||
}
|
||||
|
||||
static void Set(const onnx::TypeProto* elem_proto, ONNX_NAMESPACE::TypeProto& proto) {
|
||||
ORT_ENFORCE(elem_proto != nullptr, "expected a registered ONNX type");
|
||||
#else
|
||||
ORT_ENFORCE(elem_proto != nullptr, typeid(T).name(),
|
||||
" expected to be a registered ONNX type");
|
||||
#endif
|
||||
CopyMutableSeqElement(*elem_proto, proto);
|
||||
}
|
||||
};
|
||||
|
|
@ -403,9 +419,10 @@ struct SetSequenceType {
|
|||
void CopyMutableOptionalElement(const ONNX_NAMESPACE::TypeProto&,
|
||||
ONNX_NAMESPACE::TypeProto&);
|
||||
|
||||
template <typename T, typename elemT>
|
||||
struct SetOptionalType {
|
||||
static void Set(ONNX_NAMESPACE::TypeProto& proto) {
|
||||
// helper to create TypeProto with minimal binary size impact
|
||||
struct OptionalTypeHelper {
|
||||
template <typename T, typename elemT>
|
||||
static const onnx::TypeProto* GetElemType() {
|
||||
const onnx::TypeProto* elem_proto = nullptr;
|
||||
if (std::is_same<T, Tensor>::value) {
|
||||
MLDataType dt = DataTypeImpl::GetTensorType<elemT>();
|
||||
|
|
@ -413,17 +430,13 @@ struct SetOptionalType {
|
|||
} else if (std::is_same<T, TensorSeq>::value) {
|
||||
MLDataType dt = DataTypeImpl::GetSequenceTensorType<elemT>();
|
||||
elem_proto = dt->GetTypeProto();
|
||||
} else {
|
||||
// Will not reach here
|
||||
ORT_ENFORCE(false, "Unsupported type for optional type");
|
||||
}
|
||||
|
||||
#ifdef ORT_NO_RTTI
|
||||
ORT_ENFORCE(elem_proto != nullptr, "expected a registered ORT type");
|
||||
#else
|
||||
ORT_ENFORCE(elem_proto != nullptr, typeid(T).name(),
|
||||
" expected to be a registered ORT type");
|
||||
#endif
|
||||
return elem_proto; // check for nullptr is in Set
|
||||
}
|
||||
|
||||
static void Set(const onnx::TypeProto* elem_proto, ONNX_NAMESPACE::TypeProto& proto) {
|
||||
ORT_ENFORCE(elem_proto != nullptr, " unregistered or unsupported ORT type for Optional");
|
||||
CopyMutableOptionalElement(*elem_proto, proto);
|
||||
}
|
||||
};
|
||||
|
|
@ -445,16 +458,10 @@ class TensorTypeBase : public DataTypeImpl {
|
|||
/// where TypeProto was created ad-hoc and not queried from MLDataType
|
||||
bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
|
||||
|
||||
bool IsTensorType() const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
const TensorTypeBase* AsTensorType() const override {
|
||||
return this;
|
||||
}
|
||||
|
||||
size_t Size() const override;
|
||||
|
||||
DeleteFunc GetDeleteFunc() const override;
|
||||
|
||||
const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
|
||||
|
|
@ -464,8 +471,7 @@ class TensorTypeBase : public DataTypeImpl {
|
|||
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
|
||||
}
|
||||
|
||||
TensorTypeBase(const TensorTypeBase&) = delete;
|
||||
TensorTypeBase& operator=(const TensorTypeBase&) = delete;
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorTypeBase);
|
||||
|
||||
protected:
|
||||
ONNX_NAMESPACE::TypeProto& MutableTypeProto();
|
||||
|
|
@ -527,10 +533,6 @@ class DisabledTypeBase : public DataTypeImpl {
|
|||
return false;
|
||||
}
|
||||
|
||||
size_t Size() const override {
|
||||
ORT_THROW("Type is disabled in this build.");
|
||||
}
|
||||
|
||||
DeleteFunc GetDeleteFunc() const override {
|
||||
ORT_THROW("Type is disabled in this build.");
|
||||
}
|
||||
|
|
@ -544,7 +546,7 @@ class DisabledTypeBase : public DataTypeImpl {
|
|||
// This must work
|
||||
ONNX_NAMESPACE::TypeProto& MutableTypeProto();
|
||||
|
||||
DisabledTypeBase();
|
||||
DisabledTypeBase(DataTypeImpl::GeneralType type, size_t size);
|
||||
~DisabledTypeBase() override;
|
||||
|
||||
private:
|
||||
|
|
@ -560,18 +562,12 @@ class SparseTensorTypeBase : public DataTypeImpl {
|
|||
public:
|
||||
static MLDataType Type();
|
||||
|
||||
bool IsSparseTensorType() const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
const SparseTensorTypeBase* AsSparseTensorType() const override {
|
||||
return this;
|
||||
}
|
||||
|
||||
bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
|
||||
|
||||
size_t Size() const override;
|
||||
|
||||
DeleteFunc GetDeleteFunc() const override;
|
||||
|
||||
const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
|
||||
|
|
@ -581,8 +577,7 @@ class SparseTensorTypeBase : public DataTypeImpl {
|
|||
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
|
||||
}
|
||||
|
||||
SparseTensorTypeBase(const SparseTensorTypeBase&) = delete;
|
||||
SparseTensorTypeBase& operator=(const SparseTensorTypeBase&) = delete;
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SparseTensorTypeBase);
|
||||
|
||||
protected:
|
||||
ONNX_NAMESPACE::TypeProto& MutableTypeProto();
|
||||
|
|
@ -624,21 +619,12 @@ class OptionalTypeBase : public DataTypeImpl {
|
|||
public:
|
||||
static MLDataType Type();
|
||||
|
||||
bool IsOptionalType() const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
const OptionalTypeBase* AsOptionalType() const override {
|
||||
return this;
|
||||
}
|
||||
|
||||
bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
|
||||
|
||||
size_t Size() const override {
|
||||
// should never reach here.
|
||||
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
|
||||
}
|
||||
|
||||
DeleteFunc GetDeleteFunc() const override {
|
||||
// should never reach here.
|
||||
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
|
||||
|
|
@ -699,20 +685,26 @@ class OptionalType :
|
|||
#endif
|
||||
|
||||
private:
|
||||
OptionalType() {
|
||||
data_types_internal::SetOptionalType<T, elemT>::Set(MutableTypeProto());
|
||||
#if !defined(DISABLE_OPTIONAL_TYPE)
|
||||
OptionalType()
|
||||
#else
|
||||
OptionalType() : DisabledTypeBase { DataTypeImpl::GeneralType::kOptional, 0 }
|
||||
#endif
|
||||
{
|
||||
using namespace data_types_internal;
|
||||
OptionalTypeHelper::Set(OptionalTypeHelper::GetElemType<T, elemT>(), MutableTypeProto());
|
||||
}
|
||||
};
|
||||
}; // namespace onnxruntime
|
||||
|
||||
/**
|
||||
* \brief Provide a specialization for your C++ Non-tensor type
|
||||
* so your implementation FromDataTypeContainer/ToDataTypeContainer
|
||||
* functions correctly. Otherwise you get a default implementation
|
||||
* which may not be what you need/want.
|
||||
*
|
||||
* This class is used to create OrtValue, fetch data from OrtValue via
|
||||
* C/C++ APIs
|
||||
*/
|
||||
* \brief Provide a specialization for your C++ Non-tensor type
|
||||
* so your implementation FromDataTypeContainer/ToDataTypeContainer
|
||||
* functions correctly. Otherwise you get a default implementation
|
||||
* which may not be what you need/want.
|
||||
*
|
||||
* This class is used to create OrtValue, fetch data from OrtValue via
|
||||
* C/C++ APIs
|
||||
*/
|
||||
template <class T>
|
||||
struct NonTensorTypeConverter {
|
||||
static void FromContainer(MLDataType /*dtype*/, const void* /*data*/, size_t /*data_size*/, OrtValue& /*output*/) {
|
||||
|
|
@ -728,8 +720,6 @@ struct NonTensorTypeConverter {
|
|||
*/
|
||||
class NonTensorTypeBase : public DataTypeImpl {
|
||||
public:
|
||||
size_t Size() const override = 0;
|
||||
|
||||
DeleteFunc GetDeleteFunc() const override = 0;
|
||||
|
||||
virtual CreateFunc GetCreateFunc() const = 0;
|
||||
|
|
@ -766,7 +756,7 @@ class NonTensorTypeBase : public DataTypeImpl {
|
|||
NonTensorTypeBase& operator=(const NonTensorTypeBase&) = delete;
|
||||
|
||||
protected:
|
||||
NonTensorTypeBase();
|
||||
NonTensorTypeBase(size_t size);
|
||||
~NonTensorTypeBase() override;
|
||||
|
||||
ONNX_NAMESPACE::TypeProto& MutableTypeProto();
|
||||
|
|
@ -791,10 +781,6 @@ class NonTensorType : public NonTensorTypeBase {
|
|||
}
|
||||
|
||||
public:
|
||||
size_t Size() const override {
|
||||
return sizeof(T);
|
||||
}
|
||||
|
||||
DeleteFunc GetDeleteFunc() const override {
|
||||
return &Delete;
|
||||
}
|
||||
|
|
@ -804,7 +790,7 @@ class NonTensorType : public NonTensorTypeBase {
|
|||
}
|
||||
|
||||
protected:
|
||||
NonTensorType() = default;
|
||||
NonTensorType() : NonTensorTypeBase(sizeof(T)) {}
|
||||
};
|
||||
|
||||
#if !defined(DISABLE_ML_OPS)
|
||||
|
|
@ -858,7 +844,9 @@ class SequenceType : public NonTensorType<CPPType> {
|
|||
|
||||
private:
|
||||
SequenceType() {
|
||||
data_types_internal::SetSequenceType<typename CPPType::value_type>::Set(this->MutableTypeProto());
|
||||
using namespace data_types_internal;
|
||||
SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<typename CPPType::value_type>(),
|
||||
this->MutableTypeProto());
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -873,10 +861,6 @@ class SequenceTensorTypeBase : public DataTypeImpl {
|
|||
|
||||
bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
|
||||
|
||||
bool IsTensorSequenceType() const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
const SequenceTensorTypeBase* AsSequenceTensorType() const override {
|
||||
return this;
|
||||
}
|
||||
|
|
@ -886,8 +870,6 @@ class SequenceTensorTypeBase : public DataTypeImpl {
|
|||
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
|
||||
}
|
||||
|
||||
size_t Size() const override;
|
||||
|
||||
DeleteFunc GetDeleteFunc() const override;
|
||||
|
||||
const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
|
||||
|
|
@ -931,7 +913,8 @@ class SequenceTensorType : public SequenceTensorTypeBase {
|
|||
|
||||
private:
|
||||
SequenceTensorType() {
|
||||
data_types_internal::SetSequenceType<TensorElemType>::Set(this->MutableTypeProto());
|
||||
using namespace data_types_internal;
|
||||
SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<TensorElemType>(), MutableTypeProto());
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -998,11 +981,8 @@ class PrimitiveDataTypeBase : public DataTypeImpl {
|
|||
}
|
||||
|
||||
protected:
|
||||
PrimitiveDataTypeBase() = default;
|
||||
|
||||
void SetDataType(int32_t data_type) {
|
||||
data_type_ = data_type;
|
||||
}
|
||||
PrimitiveDataTypeBase(size_t size, int32_t data_type)
|
||||
: DataTypeImpl{GeneralType::kNonTensor, size}, data_type_{data_type} {}
|
||||
|
||||
private:
|
||||
int32_t data_type_;
|
||||
|
|
@ -1015,7 +995,7 @@ class PrimitiveDataTypeBase : public DataTypeImpl {
|
|||
*
|
||||
* \param T - primitive data type
|
||||
*
|
||||
*/
|
||||
*/
|
||||
template <typename T>
|
||||
class PrimitiveDataType : public PrimitiveDataTypeBase {
|
||||
private:
|
||||
|
|
@ -1026,17 +1006,14 @@ class PrimitiveDataType : public PrimitiveDataTypeBase {
|
|||
public:
|
||||
static MLDataType Type();
|
||||
|
||||
size_t Size() const override {
|
||||
return sizeof(T);
|
||||
}
|
||||
|
||||
DeleteFunc GetDeleteFunc() const override {
|
||||
return &Delete;
|
||||
}
|
||||
|
||||
private:
|
||||
PrimitiveDataType() {
|
||||
this->SetDataType(data_types_internal::TensorElementTypeSetter<T>::GetElementType());
|
||||
PrimitiveDataType()
|
||||
: PrimitiveDataTypeBase{sizeof(T),
|
||||
data_types_internal::TensorElementTypeSetter<T>::GetElementType()} {
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -59,10 +59,10 @@ MLDataType DataTypeImpl::GetType<TensorSeq>() {
|
|||
return SequenceTensorTypeBase::Type();
|
||||
}
|
||||
|
||||
//static bool IsTensorTypeScalar(const ONNX_NAMESPACE::TypeProto_Tensor& tensor_type_proto) {
|
||||
// int sz = tensor_type_proto.shape().dim_size();
|
||||
// return sz == 0 || sz == 1;
|
||||
//}
|
||||
// static bool IsTensorTypeScalar(const ONNX_NAMESPACE::TypeProto_Tensor& tensor_type_proto) {
|
||||
// int sz = tensor_type_proto.shape().dim_size();
|
||||
// return sz == 0 || sz == 1;
|
||||
// }
|
||||
|
||||
namespace data_types_internal {
|
||||
|
||||
|
|
@ -368,15 +368,13 @@ const ONNX_NAMESPACE::TypeProto* TensorTypeBase::GetTypeProto() const {
|
|||
return impl_->GetProto();
|
||||
}
|
||||
|
||||
TensorTypeBase::TensorTypeBase() : impl_(new Impl()) {}
|
||||
TensorTypeBase::TensorTypeBase()
|
||||
: DataTypeImpl{DataTypeImpl::GeneralType::kTensor, sizeof(Tensor)},
|
||||
impl_(new Impl()) {}
|
||||
TensorTypeBase::~TensorTypeBase() {
|
||||
delete impl_;
|
||||
}
|
||||
|
||||
size_t TensorTypeBase::Size() const {
|
||||
return sizeof(Tensor);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void Delete(void* p) {
|
||||
delete static_cast<T*>(p);
|
||||
|
|
@ -417,7 +415,10 @@ MLDataType TensorTypeBase::Type() {
|
|||
struct SparseTensorTypeBase::Impl : public data_types_internal::TypeProtoImpl {
|
||||
};
|
||||
|
||||
SparseTensorTypeBase::SparseTensorTypeBase() : impl_(new Impl()) {}
|
||||
SparseTensorTypeBase::SparseTensorTypeBase()
|
||||
: DataTypeImpl{DataTypeImpl::GeneralType::kSparseTensor, sizeof(SparseTensor)},
|
||||
impl_(new Impl()) {}
|
||||
|
||||
SparseTensorTypeBase::~SparseTensorTypeBase() {
|
||||
delete impl_;
|
||||
}
|
||||
|
|
@ -437,10 +438,6 @@ bool SparseTensorTypeBase::IsCompatible(const ONNX_NAMESPACE::TypeProto& type_pr
|
|||
return data_types_internal::IsCompatible(thisProto->sparse_tensor_type(), type_proto.sparse_tensor_type());
|
||||
}
|
||||
|
||||
size_t SparseTensorTypeBase::Size() const {
|
||||
return sizeof(SparseTensor);
|
||||
}
|
||||
|
||||
DeleteFunc SparseTensorTypeBase::GetDeleteFunc() const {
|
||||
return &Delete<SparseTensor>;
|
||||
}
|
||||
|
|
@ -464,7 +461,9 @@ MLDataType SparseTensorTypeBase::Type() {
|
|||
struct SequenceTensorTypeBase::Impl : public data_types_internal::TypeProtoImpl {
|
||||
};
|
||||
|
||||
SequenceTensorTypeBase::SequenceTensorTypeBase() : impl_(new Impl()) {}
|
||||
SequenceTensorTypeBase::SequenceTensorTypeBase()
|
||||
: DataTypeImpl{DataTypeImpl::GeneralType::kTensorSequence, sizeof(TensorSeq)},
|
||||
impl_(new Impl()) {}
|
||||
|
||||
SequenceTensorTypeBase::~SequenceTensorTypeBase() {
|
||||
delete impl_;
|
||||
|
|
@ -489,10 +488,6 @@ bool SequenceTensorTypeBase::IsCompatible(const ONNX_NAMESPACE::TypeProto& type_
|
|||
return data_types_internal::IsCompatible(thisProto->sequence_type(), type_proto.sequence_type());
|
||||
}
|
||||
|
||||
size_t SequenceTensorTypeBase::Size() const {
|
||||
return sizeof(TensorSeq);
|
||||
}
|
||||
|
||||
DeleteFunc SequenceTensorTypeBase::GetDeleteFunc() const {
|
||||
return &Delete<TensorSeq>;
|
||||
}
|
||||
|
|
@ -516,7 +511,8 @@ MLDataType SequenceTensorTypeBase::Type() {
|
|||
struct OptionalTypeBase::Impl : public data_types_internal::TypeProtoImpl {
|
||||
};
|
||||
|
||||
OptionalTypeBase::OptionalTypeBase() : impl_(new Impl()) {}
|
||||
OptionalTypeBase::OptionalTypeBase() : DataTypeImpl{DataTypeImpl::GeneralType::kOptional, 0},
|
||||
impl_(new Impl()) {}
|
||||
|
||||
OptionalTypeBase::~OptionalTypeBase() {
|
||||
delete impl_;
|
||||
|
|
@ -557,7 +553,8 @@ MLDataType OptionalTypeBase::Type() {
|
|||
struct DisabledTypeBase::Impl : public data_types_internal::TypeProtoImpl {
|
||||
};
|
||||
|
||||
DisabledTypeBase::DisabledTypeBase() : impl_(new Impl()) {}
|
||||
DisabledTypeBase::DisabledTypeBase(DataTypeImpl::GeneralType type, size_t size)
|
||||
: DataTypeImpl{type, size}, impl_(new Impl()) {}
|
||||
|
||||
DisabledTypeBase::~DisabledTypeBase() {
|
||||
delete impl_;
|
||||
|
|
@ -572,7 +569,7 @@ ONNX_NAMESPACE::TypeProto& DisabledTypeBase::MutableTypeProto() {
|
|||
}
|
||||
|
||||
MLDataType DisabledTypeBase::Type() {
|
||||
static DisabledTypeBase disabled_base;
|
||||
static DisabledTypeBase disabled_base{GeneralType::kInvalid, 0};
|
||||
return &disabled_base;
|
||||
}
|
||||
#endif
|
||||
|
|
@ -580,7 +577,9 @@ MLDataType DisabledTypeBase::Type() {
|
|||
/// NoTensorTypeBase
|
||||
struct NonTensorTypeBase::Impl : public data_types_internal::TypeProtoImpl {};
|
||||
|
||||
NonTensorTypeBase::NonTensorTypeBase() : impl_(new Impl()) {
|
||||
NonTensorTypeBase::NonTensorTypeBase(size_t size)
|
||||
: DataTypeImpl{DataTypeImpl::GeneralType::kNonTensor, size},
|
||||
impl_(new Impl()) {
|
||||
}
|
||||
|
||||
NonTensorTypeBase::~NonTensorTypeBase() {
|
||||
|
|
@ -1037,8 +1036,8 @@ MLDataType DataTypeImpl::TypeFromProto(const ONNX_NAMESPACE::TypeProto& proto) {
|
|||
return type;
|
||||
}
|
||||
|
||||
//Below are the types the we need to execute the runtime
|
||||
//They are not compatible with TypeProto in ONNX.
|
||||
// Below are the types the we need to execute the runtime
|
||||
// They are not compatible with TypeProto in ONNX.
|
||||
ORT_REGISTER_PRIM_TYPE(int32_t);
|
||||
ORT_REGISTER_PRIM_TYPE(float);
|
||||
ORT_REGISTER_PRIM_TYPE(bool);
|
||||
|
|
|
|||
Loading…
Reference in a new issue