From dc1724b0e272b2804fa8eb7ddd5ec700efdc7150 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 19 Nov 2021 10:29:12 +1000 Subject: [PATCH] Reduce DataTypeImpl binary size (#9783) * Reduce the number of virtual methods in DataTypeImpl to reduce binary size. Refactor some helpers to reduce the amount of templatized code. --- .../onnxruntime/core/framework/data_types.h | 177 ++++++++---------- onnxruntime/core/framework/data_types.cc | 49 +++-- 2 files changed, 101 insertions(+), 125 deletions(-) diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index f311251385..e3d43f9e0e 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -33,7 +33,7 @@ namespace onnxruntime { #if !defined(DISABLE_ML_OPS) -//maps (only used by ML ops) +// maps (only used by ML ops) using MapStringToString = std::map; using MapStringToInt64 = std::map; using MapStringToFloat = std::map; @@ -43,7 +43,7 @@ using MapInt64ToInt64 = std::map; using MapInt64ToFloat = std::map; using MapInt64ToDouble = std::map; -//vectors/sequences +// vectors/sequences using VectorMapStringToFloat = std::vector; using VectorMapInt64ToFloat = std::vector; @@ -78,6 +78,22 @@ using CreateFunc = void* (*)(); * */ class DataTypeImpl { + public: + enum class GeneralType { + kInvalid = 0, + kNonTensor = 1, + kTensor = 2, + kTensorSequence = 3, + kSparseTensor = 4, + kOptional = 5 + }; + + GeneralType type_; + size_t size_; + + protected: + DataTypeImpl(GeneralType type, size_t size) : type_{type}, size_{size} {} + public: virtual ~DataTypeImpl() = default; @@ -90,7 +106,7 @@ class DataTypeImpl { */ virtual bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const = 0; - virtual size_t Size() const = 0; + size_t Size() const { return size_; } virtual DeleteFunc GetDeleteFunc() const = 0; @@ -102,20 +118,20 @@ class DataTypeImpl { */ virtual const ONNX_NAMESPACE::TypeProto* GetTypeProto() const = 0; - virtual bool IsTensorType() const { - return false; + bool IsTensorType() const { + return type_ == GeneralType::kTensor; } - virtual bool IsTensorSequenceType() const { - return false; + bool IsTensorSequenceType() const { + return type_ == GeneralType::kTensorSequence; } - virtual bool IsSparseTensorType() const { - return false; + bool IsSparseTensorType() const { + return type_ == GeneralType::kSparseTensor; } - virtual bool IsOptionalType() const { - return false; + bool IsOptionalType() const { + return type_ == GeneralType::kOptional; } // Returns this if this is of tensor-type and null otherwise @@ -383,17 +399,17 @@ struct SetMapTypes { void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto&, ONNX_NAMESPACE::TypeProto&); -template -struct SetSequenceType { - static void Set(ONNX_NAMESPACE::TypeProto& proto) { +// helper to create TypeProto with minimal binary size impact +struct SequenceTypeHelper { + template + static const onnx::TypeProto* GetElemType() { MLDataType dt = GetMLDataType::value>::Get(); const auto* elem_proto = dt->GetTypeProto(); -#ifdef ORT_NO_RTTI + return elem_proto; // check for nullptr is in Set + } + + static void Set(const onnx::TypeProto* elem_proto, ONNX_NAMESPACE::TypeProto& proto) { ORT_ENFORCE(elem_proto != nullptr, "expected a registered ONNX type"); -#else - ORT_ENFORCE(elem_proto != nullptr, typeid(T).name(), - " expected to be a registered ONNX type"); -#endif CopyMutableSeqElement(*elem_proto, proto); } }; @@ -403,9 +419,10 @@ struct SetSequenceType { void CopyMutableOptionalElement(const ONNX_NAMESPACE::TypeProto&, ONNX_NAMESPACE::TypeProto&); -template -struct SetOptionalType { - static void Set(ONNX_NAMESPACE::TypeProto& proto) { +// helper to create TypeProto with minimal binary size impact +struct OptionalTypeHelper { + template + static const onnx::TypeProto* GetElemType() { const onnx::TypeProto* elem_proto = nullptr; if (std::is_same::value) { MLDataType dt = DataTypeImpl::GetTensorType(); @@ -413,17 +430,13 @@ struct SetOptionalType { } else if (std::is_same::value) { MLDataType dt = DataTypeImpl::GetSequenceTensorType(); elem_proto = dt->GetTypeProto(); - } else { - // Will not reach here - ORT_ENFORCE(false, "Unsupported type for optional type"); } -#ifdef ORT_NO_RTTI - ORT_ENFORCE(elem_proto != nullptr, "expected a registered ORT type"); -#else - ORT_ENFORCE(elem_proto != nullptr, typeid(T).name(), - " expected to be a registered ORT type"); -#endif + return elem_proto; // check for nullptr is in Set + } + + static void Set(const onnx::TypeProto* elem_proto, ONNX_NAMESPACE::TypeProto& proto) { + ORT_ENFORCE(elem_proto != nullptr, " unregistered or unsupported ORT type for Optional"); CopyMutableOptionalElement(*elem_proto, proto); } }; @@ -445,16 +458,10 @@ class TensorTypeBase : public DataTypeImpl { /// where TypeProto was created ad-hoc and not queried from MLDataType bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override; - bool IsTensorType() const override { - return true; - } - const TensorTypeBase* AsTensorType() const override { return this; } - size_t Size() const override; - DeleteFunc GetDeleteFunc() const override; const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override; @@ -464,8 +471,7 @@ class TensorTypeBase : public DataTypeImpl { ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); } - TensorTypeBase(const TensorTypeBase&) = delete; - TensorTypeBase& operator=(const TensorTypeBase&) = delete; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorTypeBase); protected: ONNX_NAMESPACE::TypeProto& MutableTypeProto(); @@ -527,10 +533,6 @@ class DisabledTypeBase : public DataTypeImpl { return false; } - size_t Size() const override { - ORT_THROW("Type is disabled in this build."); - } - DeleteFunc GetDeleteFunc() const override { ORT_THROW("Type is disabled in this build."); } @@ -544,7 +546,7 @@ class DisabledTypeBase : public DataTypeImpl { // This must work ONNX_NAMESPACE::TypeProto& MutableTypeProto(); - DisabledTypeBase(); + DisabledTypeBase(DataTypeImpl::GeneralType type, size_t size); ~DisabledTypeBase() override; private: @@ -560,18 +562,12 @@ class SparseTensorTypeBase : public DataTypeImpl { public: static MLDataType Type(); - bool IsSparseTensorType() const override { - return true; - } - const SparseTensorTypeBase* AsSparseTensorType() const override { return this; } bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override; - size_t Size() const override; - DeleteFunc GetDeleteFunc() const override; const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override; @@ -581,8 +577,7 @@ class SparseTensorTypeBase : public DataTypeImpl { ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); } - SparseTensorTypeBase(const SparseTensorTypeBase&) = delete; - SparseTensorTypeBase& operator=(const SparseTensorTypeBase&) = delete; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SparseTensorTypeBase); protected: ONNX_NAMESPACE::TypeProto& MutableTypeProto(); @@ -624,21 +619,12 @@ class OptionalTypeBase : public DataTypeImpl { public: static MLDataType Type(); - bool IsOptionalType() const override { - return true; - } - const OptionalTypeBase* AsOptionalType() const override { return this; } bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override; - size_t Size() const override { - // should never reach here. - ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); - } - DeleteFunc GetDeleteFunc() const override { // should never reach here. ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); @@ -699,20 +685,26 @@ class OptionalType : #endif private: - OptionalType() { - data_types_internal::SetOptionalType::Set(MutableTypeProto()); +#if !defined(DISABLE_OPTIONAL_TYPE) + OptionalType() +#else + OptionalType() : DisabledTypeBase { DataTypeImpl::GeneralType::kOptional, 0 } +#endif + { + using namespace data_types_internal; + OptionalTypeHelper::Set(OptionalTypeHelper::GetElemType(), MutableTypeProto()); } -}; +}; // namespace onnxruntime /** - * \brief Provide a specialization for your C++ Non-tensor type - * so your implementation FromDataTypeContainer/ToDataTypeContainer - * functions correctly. Otherwise you get a default implementation - * which may not be what you need/want. - * - * This class is used to create OrtValue, fetch data from OrtValue via - * C/C++ APIs - */ + * \brief Provide a specialization for your C++ Non-tensor type + * so your implementation FromDataTypeContainer/ToDataTypeContainer + * functions correctly. Otherwise you get a default implementation + * which may not be what you need/want. + * + * This class is used to create OrtValue, fetch data from OrtValue via + * C/C++ APIs + */ template struct NonTensorTypeConverter { static void FromContainer(MLDataType /*dtype*/, const void* /*data*/, size_t /*data_size*/, OrtValue& /*output*/) { @@ -728,8 +720,6 @@ struct NonTensorTypeConverter { */ class NonTensorTypeBase : public DataTypeImpl { public: - size_t Size() const override = 0; - DeleteFunc GetDeleteFunc() const override = 0; virtual CreateFunc GetCreateFunc() const = 0; @@ -766,7 +756,7 @@ class NonTensorTypeBase : public DataTypeImpl { NonTensorTypeBase& operator=(const NonTensorTypeBase&) = delete; protected: - NonTensorTypeBase(); + NonTensorTypeBase(size_t size); ~NonTensorTypeBase() override; ONNX_NAMESPACE::TypeProto& MutableTypeProto(); @@ -791,10 +781,6 @@ class NonTensorType : public NonTensorTypeBase { } public: - size_t Size() const override { - return sizeof(T); - } - DeleteFunc GetDeleteFunc() const override { return &Delete; } @@ -804,7 +790,7 @@ class NonTensorType : public NonTensorTypeBase { } protected: - NonTensorType() = default; + NonTensorType() : NonTensorTypeBase(sizeof(T)) {} }; #if !defined(DISABLE_ML_OPS) @@ -858,7 +844,9 @@ class SequenceType : public NonTensorType { private: SequenceType() { - data_types_internal::SetSequenceType::Set(this->MutableTypeProto()); + using namespace data_types_internal; + SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType(), + this->MutableTypeProto()); } }; @@ -873,10 +861,6 @@ class SequenceTensorTypeBase : public DataTypeImpl { bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override; - bool IsTensorSequenceType() const override { - return true; - } - const SequenceTensorTypeBase* AsSequenceTensorType() const override { return this; } @@ -886,8 +870,6 @@ class SequenceTensorTypeBase : public DataTypeImpl { ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); } - size_t Size() const override; - DeleteFunc GetDeleteFunc() const override; const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override; @@ -931,7 +913,8 @@ class SequenceTensorType : public SequenceTensorTypeBase { private: SequenceTensorType() { - data_types_internal::SetSequenceType::Set(this->MutableTypeProto()); + using namespace data_types_internal; + SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType(), MutableTypeProto()); } }; @@ -998,11 +981,8 @@ class PrimitiveDataTypeBase : public DataTypeImpl { } protected: - PrimitiveDataTypeBase() = default; - - void SetDataType(int32_t data_type) { - data_type_ = data_type; - } + PrimitiveDataTypeBase(size_t size, int32_t data_type) + : DataTypeImpl{GeneralType::kNonTensor, size}, data_type_{data_type} {} private: int32_t data_type_; @@ -1015,7 +995,7 @@ class PrimitiveDataTypeBase : public DataTypeImpl { * * \param T - primitive data type * - */ + */ template class PrimitiveDataType : public PrimitiveDataTypeBase { private: @@ -1026,17 +1006,14 @@ class PrimitiveDataType : public PrimitiveDataTypeBase { public: static MLDataType Type(); - size_t Size() const override { - return sizeof(T); - } - DeleteFunc GetDeleteFunc() const override { return &Delete; } private: - PrimitiveDataType() { - this->SetDataType(data_types_internal::TensorElementTypeSetter::GetElementType()); + PrimitiveDataType() + : PrimitiveDataTypeBase{sizeof(T), + data_types_internal::TensorElementTypeSetter::GetElementType()} { } }; diff --git a/onnxruntime/core/framework/data_types.cc b/onnxruntime/core/framework/data_types.cc index 8038d819fe..2b33c99beb 100644 --- a/onnxruntime/core/framework/data_types.cc +++ b/onnxruntime/core/framework/data_types.cc @@ -59,10 +59,10 @@ MLDataType DataTypeImpl::GetType() { return SequenceTensorTypeBase::Type(); } -//static bool IsTensorTypeScalar(const ONNX_NAMESPACE::TypeProto_Tensor& tensor_type_proto) { -// int sz = tensor_type_proto.shape().dim_size(); -// return sz == 0 || sz == 1; -//} +// static bool IsTensorTypeScalar(const ONNX_NAMESPACE::TypeProto_Tensor& tensor_type_proto) { +// int sz = tensor_type_proto.shape().dim_size(); +// return sz == 0 || sz == 1; +// } namespace data_types_internal { @@ -368,15 +368,13 @@ const ONNX_NAMESPACE::TypeProto* TensorTypeBase::GetTypeProto() const { return impl_->GetProto(); } -TensorTypeBase::TensorTypeBase() : impl_(new Impl()) {} +TensorTypeBase::TensorTypeBase() + : DataTypeImpl{DataTypeImpl::GeneralType::kTensor, sizeof(Tensor)}, + impl_(new Impl()) {} TensorTypeBase::~TensorTypeBase() { delete impl_; } -size_t TensorTypeBase::Size() const { - return sizeof(Tensor); -} - template static void Delete(void* p) { delete static_cast(p); @@ -417,7 +415,10 @@ MLDataType TensorTypeBase::Type() { struct SparseTensorTypeBase::Impl : public data_types_internal::TypeProtoImpl { }; -SparseTensorTypeBase::SparseTensorTypeBase() : impl_(new Impl()) {} +SparseTensorTypeBase::SparseTensorTypeBase() + : DataTypeImpl{DataTypeImpl::GeneralType::kSparseTensor, sizeof(SparseTensor)}, + impl_(new Impl()) {} + SparseTensorTypeBase::~SparseTensorTypeBase() { delete impl_; } @@ -437,10 +438,6 @@ bool SparseTensorTypeBase::IsCompatible(const ONNX_NAMESPACE::TypeProto& type_pr return data_types_internal::IsCompatible(thisProto->sparse_tensor_type(), type_proto.sparse_tensor_type()); } -size_t SparseTensorTypeBase::Size() const { - return sizeof(SparseTensor); -} - DeleteFunc SparseTensorTypeBase::GetDeleteFunc() const { return &Delete; } @@ -464,7 +461,9 @@ MLDataType SparseTensorTypeBase::Type() { struct SequenceTensorTypeBase::Impl : public data_types_internal::TypeProtoImpl { }; -SequenceTensorTypeBase::SequenceTensorTypeBase() : impl_(new Impl()) {} +SequenceTensorTypeBase::SequenceTensorTypeBase() + : DataTypeImpl{DataTypeImpl::GeneralType::kTensorSequence, sizeof(TensorSeq)}, + impl_(new Impl()) {} SequenceTensorTypeBase::~SequenceTensorTypeBase() { delete impl_; @@ -489,10 +488,6 @@ bool SequenceTensorTypeBase::IsCompatible(const ONNX_NAMESPACE::TypeProto& type_ return data_types_internal::IsCompatible(thisProto->sequence_type(), type_proto.sequence_type()); } -size_t SequenceTensorTypeBase::Size() const { - return sizeof(TensorSeq); -} - DeleteFunc SequenceTensorTypeBase::GetDeleteFunc() const { return &Delete; } @@ -516,7 +511,8 @@ MLDataType SequenceTensorTypeBase::Type() { struct OptionalTypeBase::Impl : public data_types_internal::TypeProtoImpl { }; -OptionalTypeBase::OptionalTypeBase() : impl_(new Impl()) {} +OptionalTypeBase::OptionalTypeBase() : DataTypeImpl{DataTypeImpl::GeneralType::kOptional, 0}, + impl_(new Impl()) {} OptionalTypeBase::~OptionalTypeBase() { delete impl_; @@ -557,7 +553,8 @@ MLDataType OptionalTypeBase::Type() { struct DisabledTypeBase::Impl : public data_types_internal::TypeProtoImpl { }; -DisabledTypeBase::DisabledTypeBase() : impl_(new Impl()) {} +DisabledTypeBase::DisabledTypeBase(DataTypeImpl::GeneralType type, size_t size) + : DataTypeImpl{type, size}, impl_(new Impl()) {} DisabledTypeBase::~DisabledTypeBase() { delete impl_; @@ -572,7 +569,7 @@ ONNX_NAMESPACE::TypeProto& DisabledTypeBase::MutableTypeProto() { } MLDataType DisabledTypeBase::Type() { - static DisabledTypeBase disabled_base; + static DisabledTypeBase disabled_base{GeneralType::kInvalid, 0}; return &disabled_base; } #endif @@ -580,7 +577,9 @@ MLDataType DisabledTypeBase::Type() { /// NoTensorTypeBase struct NonTensorTypeBase::Impl : public data_types_internal::TypeProtoImpl {}; -NonTensorTypeBase::NonTensorTypeBase() : impl_(new Impl()) { +NonTensorTypeBase::NonTensorTypeBase(size_t size) + : DataTypeImpl{DataTypeImpl::GeneralType::kNonTensor, size}, + impl_(new Impl()) { } NonTensorTypeBase::~NonTensorTypeBase() { @@ -1037,8 +1036,8 @@ MLDataType DataTypeImpl::TypeFromProto(const ONNX_NAMESPACE::TypeProto& proto) { return type; } -//Below are the types the we need to execute the runtime -//They are not compatible with TypeProto in ONNX. +// Below are the types the we need to execute the runtime +// They are not compatible with TypeProto in ONNX. ORT_REGISTER_PRIM_TYPE(int32_t); ORT_REGISTER_PRIM_TYPE(float); ORT_REGISTER_PRIM_TYPE(bool);