Reduce DataTypeImpl binary size (#9783)

* Reduce the number of virtual methods in DataTypeImpl to reduce binary size.
Refactor some helpers to reduce the amount of templatized code.
This commit is contained in:
Scott McKay 2021-11-19 10:29:12 +10:00 committed by GitHub
parent f390347c11
commit dc1724b0e2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 101 additions and 125 deletions

View file

@ -33,7 +33,7 @@ namespace onnxruntime {
#if !defined(DISABLE_ML_OPS)
//maps (only used by ML ops)
// maps (only used by ML ops)
using MapStringToString = std::map<std::string, std::string>;
using MapStringToInt64 = std::map<std::string, int64_t>;
using MapStringToFloat = std::map<std::string, float>;
@ -43,7 +43,7 @@ using MapInt64ToInt64 = std::map<int64_t, int64_t>;
using MapInt64ToFloat = std::map<int64_t, float>;
using MapInt64ToDouble = std::map<int64_t, double>;
//vectors/sequences
// vectors/sequences
using VectorMapStringToFloat = std::vector<MapStringToFloat>;
using VectorMapInt64ToFloat = std::vector<MapInt64ToFloat>;
@ -78,6 +78,22 @@ using CreateFunc = void* (*)();
*
*/
class DataTypeImpl {
public:
// Coarse category of a registered type. One value is stored per DataTypeImpl instance
// (in type_) and compared by the non-virtual IsTensorType() / IsTensorSequenceType() /
// IsSparseTensorType() / IsOptionalType() queries.
enum class GeneralType {
// Used by the shared placeholder instance returned from DisabledTypeBase::Type().
kInvalid = 0,
// Passed by NonTensorTypeBase and PrimitiveDataTypeBase; no Is*Type() query tests for it.
kNonTensor = 1,
// Passed by TensorTypeBase; IsTensorType() tests for this value.
kTensor = 2,
// Passed by SequenceTensorTypeBase; IsTensorSequenceType() tests for this value.
kTensorSequence = 3,
// Passed by SparseTensorTypeBase; IsSparseTensorType() tests for this value.
kSparseTensor = 4,
// Passed by OptionalTypeBase (and by OptionalType when optional support is disabled);
// IsOptionalType() tests for this value.
kOptional = 5
};
// Category tag read by IsTensorType(), IsTensorSequenceType(), IsSparseTensorType()
// and IsOptionalType().
// NOTE(review): type_ and size_ appear in the public section and are not const in this
// view -- confirm that callers are not expected to modify them.
GeneralType type_;
// Value returned by Size(); supplied by the derived class at construction.
size_t size_;
protected:
// Only derived type classes construct a DataTypeImpl. Each passes its category and the
// size that Size() should report.
DataTypeImpl(GeneralType type, size_t size) : type_{type}, size_{size} {}
public:
virtual ~DataTypeImpl() = default;
@ -90,7 +106,7 @@ class DataTypeImpl {
*/
virtual bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const = 0;
virtual size_t Size() const = 0;
size_t Size() const { return size_; }
virtual DeleteFunc GetDeleteFunc() const = 0;
@ -102,20 +118,20 @@ class DataTypeImpl {
*/
virtual const ONNX_NAMESPACE::TypeProto* GetTypeProto() const = 0;
virtual bool IsTensorType() const {
return false;
bool IsTensorType() const {
return type_ == GeneralType::kTensor;
}
virtual bool IsTensorSequenceType() const {
return false;
bool IsTensorSequenceType() const {
return type_ == GeneralType::kTensorSequence;
}
virtual bool IsSparseTensorType() const {
return false;
bool IsSparseTensorType() const {
return type_ == GeneralType::kSparseTensor;
}
virtual bool IsOptionalType() const {
return false;
bool IsOptionalType() const {
return type_ == GeneralType::kOptional;
}
// Returns this if this is of tensor-type and null otherwise
@ -383,17 +399,17 @@ struct SetMapTypes {
void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto&,
ONNX_NAMESPACE::TypeProto&);
template <typename T>
struct SetSequenceType {
static void Set(ONNX_NAMESPACE::TypeProto& proto) {
// helper to create TypeProto with minimal binary size impact
struct SequenceTypeHelper {
template <typename T>
static const onnx::TypeProto* GetElemType() {
MLDataType dt = GetMLDataType<T, IsTensorContainedType<T>::value>::Get();
const auto* elem_proto = dt->GetTypeProto();
#ifdef ORT_NO_RTTI
return elem_proto; // check for nullptr is in Set
}
static void Set(const onnx::TypeProto* elem_proto, ONNX_NAMESPACE::TypeProto& proto) {
ORT_ENFORCE(elem_proto != nullptr, "expected a registered ONNX type");
#else
ORT_ENFORCE(elem_proto != nullptr, typeid(T).name(),
" expected to be a registered ONNX type");
#endif
CopyMutableSeqElement(*elem_proto, proto);
}
};
@ -403,9 +419,10 @@ struct SetSequenceType {
void CopyMutableOptionalElement(const ONNX_NAMESPACE::TypeProto&,
ONNX_NAMESPACE::TypeProto&);
template <typename T, typename elemT>
struct SetOptionalType {
static void Set(ONNX_NAMESPACE::TypeProto& proto) {
// helper to create TypeProto with minimal binary size impact
struct OptionalTypeHelper {
template <typename T, typename elemT>
static const onnx::TypeProto* GetElemType() {
const onnx::TypeProto* elem_proto = nullptr;
if (std::is_same<T, Tensor>::value) {
MLDataType dt = DataTypeImpl::GetTensorType<elemT>();
@ -413,17 +430,13 @@ struct SetOptionalType {
} else if (std::is_same<T, TensorSeq>::value) {
MLDataType dt = DataTypeImpl::GetSequenceTensorType<elemT>();
elem_proto = dt->GetTypeProto();
} else {
// Will not reach here
ORT_ENFORCE(false, "Unsupported type for optional type");
}
#ifdef ORT_NO_RTTI
ORT_ENFORCE(elem_proto != nullptr, "expected a registered ORT type");
#else
ORT_ENFORCE(elem_proto != nullptr, typeid(T).name(),
" expected to be a registered ORT type");
#endif
return elem_proto; // check for nullptr is in Set
}
static void Set(const onnx::TypeProto* elem_proto, ONNX_NAMESPACE::TypeProto& proto) {
ORT_ENFORCE(elem_proto != nullptr, " unregistered or unsupported ORT type for Optional");
CopyMutableOptionalElement(*elem_proto, proto);
}
};
@ -445,16 +458,10 @@ class TensorTypeBase : public DataTypeImpl {
/// where TypeProto was created ad-hoc and not queried from MLDataType
bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
bool IsTensorType() const override {
return true;
}
const TensorTypeBase* AsTensorType() const override {
return this;
}
size_t Size() const override;
DeleteFunc GetDeleteFunc() const override;
const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
@ -464,8 +471,7 @@ class TensorTypeBase : public DataTypeImpl {
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
}
TensorTypeBase(const TensorTypeBase&) = delete;
TensorTypeBase& operator=(const TensorTypeBase&) = delete;
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorTypeBase);
protected:
ONNX_NAMESPACE::TypeProto& MutableTypeProto();
@ -527,10 +533,6 @@ class DisabledTypeBase : public DataTypeImpl {
return false;
}
size_t Size() const override {
ORT_THROW("Type is disabled in this build.");
}
DeleteFunc GetDeleteFunc() const override {
ORT_THROW("Type is disabled in this build.");
}
@ -544,7 +546,7 @@ class DisabledTypeBase : public DataTypeImpl {
// This must work
ONNX_NAMESPACE::TypeProto& MutableTypeProto();
DisabledTypeBase();
DisabledTypeBase(DataTypeImpl::GeneralType type, size_t size);
~DisabledTypeBase() override;
private:
@ -560,18 +562,12 @@ class SparseTensorTypeBase : public DataTypeImpl {
public:
static MLDataType Type();
bool IsSparseTensorType() const override {
return true;
}
const SparseTensorTypeBase* AsSparseTensorType() const override {
return this;
}
bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
size_t Size() const override;
DeleteFunc GetDeleteFunc() const override;
const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
@ -581,8 +577,7 @@ class SparseTensorTypeBase : public DataTypeImpl {
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
}
SparseTensorTypeBase(const SparseTensorTypeBase&) = delete;
SparseTensorTypeBase& operator=(const SparseTensorTypeBase&) = delete;
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SparseTensorTypeBase);
protected:
ONNX_NAMESPACE::TypeProto& MutableTypeProto();
@ -624,21 +619,12 @@ class OptionalTypeBase : public DataTypeImpl {
public:
static MLDataType Type();
bool IsOptionalType() const override {
return true;
}
const OptionalTypeBase* AsOptionalType() const override {
return this;
}
bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
size_t Size() const override {
// should never reach here.
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
}
DeleteFunc GetDeleteFunc() const override {
// should never reach here.
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
@ -699,20 +685,26 @@ class OptionalType :
#endif
private:
OptionalType() {
data_types_internal::SetOptionalType<T, elemT>::Set(MutableTypeProto());
#if !defined(DISABLE_OPTIONAL_TYPE)
OptionalType()
#else
OptionalType() : DisabledTypeBase { DataTypeImpl::GeneralType::kOptional, 0 }
#endif
{
using namespace data_types_internal;
OptionalTypeHelper::Set(OptionalTypeHelper::GetElemType<T, elemT>(), MutableTypeProto());
}
};
};  // class OptionalType
/**
* \brief Provide a specialization for your C++ Non-tensor type
* so your implementation FromDataTypeContainer/ToDataTypeContainer
* functions correctly. Otherwise you get a default implementation
* which may not be what you need/want.
*
* This class is used to create OrtValue, fetch data from OrtValue via
* C/C++ APIs
*/
* \brief Provide a specialization for your C++ Non-tensor type
* so your implementation FromDataTypeContainer/ToDataTypeContainer
* functions correctly. Otherwise you get a default implementation
* which may not be what you need/want.
*
* This class is used to create OrtValue, fetch data from OrtValue via
* C/C++ APIs
*/
template <class T>
struct NonTensorTypeConverter {
static void FromContainer(MLDataType /*dtype*/, const void* /*data*/, size_t /*data_size*/, OrtValue& /*output*/) {
@ -728,8 +720,6 @@ struct NonTensorTypeConverter {
*/
class NonTensorTypeBase : public DataTypeImpl {
public:
size_t Size() const override = 0;
DeleteFunc GetDeleteFunc() const override = 0;
virtual CreateFunc GetCreateFunc() const = 0;
@ -766,7 +756,7 @@ class NonTensorTypeBase : public DataTypeImpl {
NonTensorTypeBase& operator=(const NonTensorTypeBase&) = delete;
protected:
NonTensorTypeBase();
NonTensorTypeBase(size_t size);
~NonTensorTypeBase() override;
ONNX_NAMESPACE::TypeProto& MutableTypeProto();
@ -791,10 +781,6 @@ class NonTensorType : public NonTensorTypeBase {
}
public:
size_t Size() const override {
return sizeof(T);
}
DeleteFunc GetDeleteFunc() const override {
return &Delete;
}
@ -804,7 +790,7 @@ class NonTensorType : public NonTensorTypeBase {
}
protected:
NonTensorType() = default;
NonTensorType() : NonTensorTypeBase(sizeof(T)) {}
};
#if !defined(DISABLE_ML_OPS)
@ -858,7 +844,9 @@ class SequenceType : public NonTensorType<CPPType> {
private:
SequenceType() {
data_types_internal::SetSequenceType<typename CPPType::value_type>::Set(this->MutableTypeProto());
using namespace data_types_internal;
SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<typename CPPType::value_type>(),
this->MutableTypeProto());
}
};
@ -873,10 +861,6 @@ class SequenceTensorTypeBase : public DataTypeImpl {
bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
bool IsTensorSequenceType() const override {
return true;
}
const SequenceTensorTypeBase* AsSequenceTensorType() const override {
return this;
}
@ -886,8 +870,6 @@ class SequenceTensorTypeBase : public DataTypeImpl {
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
}
size_t Size() const override;
DeleteFunc GetDeleteFunc() const override;
const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
@ -931,7 +913,8 @@ class SequenceTensorType : public SequenceTensorTypeBase {
private:
SequenceTensorType() {
data_types_internal::SetSequenceType<TensorElemType>::Set(this->MutableTypeProto());
using namespace data_types_internal;
SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<TensorElemType>(), MutableTypeProto());
}
};
@ -998,11 +981,8 @@ class PrimitiveDataTypeBase : public DataTypeImpl {
}
protected:
PrimitiveDataTypeBase() = default;
void SetDataType(int32_t data_type) {
data_type_ = data_type;
}
// Registers a primitive type under GeneralType::kNonTensor.
//   size      - value that Size() reports for this type (PrimitiveDataType<T> passes sizeof(T))
//   data_type - element type id stored in data_type_
PrimitiveDataTypeBase(size_t size, int32_t data_type)
: DataTypeImpl{GeneralType::kNonTensor, size}, data_type_{data_type} {}
private:
int32_t data_type_;
@ -1015,7 +995,7 @@ class PrimitiveDataTypeBase : public DataTypeImpl {
*
* \param T - primitive data type
*
*/
*/
template <typename T>
class PrimitiveDataType : public PrimitiveDataTypeBase {
private:
@ -1026,17 +1006,14 @@ class PrimitiveDataType : public PrimitiveDataTypeBase {
public:
static MLDataType Type();
size_t Size() const override {
return sizeof(T);
}
DeleteFunc GetDeleteFunc() const override {
return &Delete;
}
private:
PrimitiveDataType() {
this->SetDataType(data_types_internal::TensorElementTypeSetter<T>::GetElementType());
PrimitiveDataType()
: PrimitiveDataTypeBase{sizeof(T),
data_types_internal::TensorElementTypeSetter<T>::GetElementType()} {
}
};

View file

@ -59,10 +59,10 @@ MLDataType DataTypeImpl::GetType<TensorSeq>() {
return SequenceTensorTypeBase::Type();
}
//static bool IsTensorTypeScalar(const ONNX_NAMESPACE::TypeProto_Tensor& tensor_type_proto) {
// int sz = tensor_type_proto.shape().dim_size();
// return sz == 0 || sz == 1;
//}
// static bool IsTensorTypeScalar(const ONNX_NAMESPACE::TypeProto_Tensor& tensor_type_proto) {
// int sz = tensor_type_proto.shape().dim_size();
// return sz == 0 || sz == 1;
// }
namespace data_types_internal {
@ -368,15 +368,13 @@ const ONNX_NAMESPACE::TypeProto* TensorTypeBase::GetTypeProto() const {
return impl_->GetProto();
}
TensorTypeBase::TensorTypeBase() : impl_(new Impl()) {}
// Registers the type as GeneralType::kTensor, with Size() reporting sizeof(Tensor).
// impl_ is allocated here and released by ~TensorTypeBase().
TensorTypeBase::TensorTypeBase()
: DataTypeImpl{DataTypeImpl::GeneralType::kTensor, sizeof(Tensor)},
impl_(new Impl()) {}
TensorTypeBase::~TensorTypeBase() {
delete impl_;
}
size_t TensorTypeBase::Size() const {
return sizeof(Tensor);
}
template <typename T>
static void Delete(void* p) {
delete static_cast<T*>(p);
@ -417,7 +415,10 @@ MLDataType TensorTypeBase::Type() {
struct SparseTensorTypeBase::Impl : public data_types_internal::TypeProtoImpl {
};
SparseTensorTypeBase::SparseTensorTypeBase() : impl_(new Impl()) {}
// Registers the type as GeneralType::kSparseTensor, with Size() reporting sizeof(SparseTensor).
// impl_ is allocated here and released by ~SparseTensorTypeBase().
SparseTensorTypeBase::SparseTensorTypeBase()
: DataTypeImpl{DataTypeImpl::GeneralType::kSparseTensor, sizeof(SparseTensor)},
impl_(new Impl()) {}
SparseTensorTypeBase::~SparseTensorTypeBase() {
delete impl_;
}
@ -437,10 +438,6 @@ bool SparseTensorTypeBase::IsCompatible(const ONNX_NAMESPACE::TypeProto& type_pr
return data_types_internal::IsCompatible(thisProto->sparse_tensor_type(), type_proto.sparse_tensor_type());
}
size_t SparseTensorTypeBase::Size() const {
return sizeof(SparseTensor);
}
DeleteFunc SparseTensorTypeBase::GetDeleteFunc() const {
return &Delete<SparseTensor>;
}
@ -464,7 +461,9 @@ MLDataType SparseTensorTypeBase::Type() {
struct SequenceTensorTypeBase::Impl : public data_types_internal::TypeProtoImpl {
};
SequenceTensorTypeBase::SequenceTensorTypeBase() : impl_(new Impl()) {}
// Registers the type as GeneralType::kTensorSequence, with Size() reporting sizeof(TensorSeq).
// impl_ is allocated here and released by ~SequenceTensorTypeBase().
SequenceTensorTypeBase::SequenceTensorTypeBase()
: DataTypeImpl{DataTypeImpl::GeneralType::kTensorSequence, sizeof(TensorSeq)},
impl_(new Impl()) {}
SequenceTensorTypeBase::~SequenceTensorTypeBase() {
delete impl_;
@ -489,10 +488,6 @@ bool SequenceTensorTypeBase::IsCompatible(const ONNX_NAMESPACE::TypeProto& type_
return data_types_internal::IsCompatible(thisProto->sequence_type(), type_proto.sequence_type());
}
size_t SequenceTensorTypeBase::Size() const {
return sizeof(TensorSeq);
}
DeleteFunc SequenceTensorTypeBase::GetDeleteFunc() const {
return &Delete<TensorSeq>;
}
@ -516,7 +511,8 @@ MLDataType SequenceTensorTypeBase::Type() {
struct OptionalTypeBase::Impl : public data_types_internal::TypeProtoImpl {
};
OptionalTypeBase::OptionalTypeBase() : impl_(new Impl()) {}
// Registers the type as GeneralType::kOptional. The size is registered as 0, so Size()
// returns 0 for optional types.
// impl_ is allocated here and released by ~OptionalTypeBase().
OptionalTypeBase::OptionalTypeBase() : DataTypeImpl{DataTypeImpl::GeneralType::kOptional, 0},
impl_(new Impl()) {}
OptionalTypeBase::~OptionalTypeBase() {
delete impl_;
@ -557,7 +553,8 @@ MLDataType OptionalTypeBase::Type() {
struct DisabledTypeBase::Impl : public data_types_internal::TypeProtoImpl {
};
DisabledTypeBase::DisabledTypeBase() : impl_(new Impl()) {}
// Forwards the caller-chosen category and size to DataTypeImpl, so a disabled type still
// answers the Is*Type() and Size() queries.
// impl_ is allocated here and released by ~DisabledTypeBase().
DisabledTypeBase::DisabledTypeBase(DataTypeImpl::GeneralType type, size_t size)
: DataTypeImpl{type, size}, impl_(new Impl()) {}
DisabledTypeBase::~DisabledTypeBase() {
delete impl_;
@ -572,7 +569,7 @@ ONNX_NAMESPACE::TypeProto& DisabledTypeBase::MutableTypeProto() {
}
MLDataType DisabledTypeBase::Type() {
static DisabledTypeBase disabled_base;
static DisabledTypeBase disabled_base{GeneralType::kInvalid, 0};
return &disabled_base;
}
#endif
@ -580,7 +577,9 @@ MLDataType DisabledTypeBase::Type() {
/// NonTensorTypeBase
struct NonTensorTypeBase::Impl : public data_types_internal::TypeProtoImpl {};
NonTensorTypeBase::NonTensorTypeBase() : impl_(new Impl()) {
NonTensorTypeBase::NonTensorTypeBase(size_t size)
: DataTypeImpl{DataTypeImpl::GeneralType::kNonTensor, size},
impl_(new Impl()) {
}
NonTensorTypeBase::~NonTensorTypeBase() {
@ -1037,8 +1036,8 @@ MLDataType DataTypeImpl::TypeFromProto(const ONNX_NAMESPACE::TypeProto& proto) {
return type;
}
//Below are the types the we need to execute the runtime
//They are not compatible with TypeProto in ONNX.
// Below are the types that we need to execute the runtime
// They are not compatible with TypeProto in ONNX.
ORT_REGISTER_PRIM_TYPE(int32_t);
ORT_REGISTER_PRIM_TYPE(float);
ORT_REGISTER_PRIM_TYPE(bool);