From 3433576fd39bb6451fb520e208a2f34a07ba4c7b Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 27 Oct 2020 10:32:06 -0700 Subject: [PATCH] Support for Sparse Initializers (#5540) Introduce sparse_initializers support. Convert them to dense on model load and prune graph_proto_ so they don't consume space. Convert back to sparse on ORT Format model save. Implement serializing sparse initializers to OrtFormat. Fix Model::ToProto() to return original sparse initializers Set a flag that graph_sync is needed when loading a simple ORT Format model. otherwise nothing is resolved. Add ORT Format history to README.md ifdef MINIMAL build for DenseToSparseTensorInitializer Allow duplicate initializers to support existing models. Issue a warning instead of aborting. * Revert "Remove SparseTensor support from minimal build. (#5114)" This reverts commit 59ee8ffb17ffce7d0bce9692d8d58555df909ff9. Signed-off-by: Dmitri Smirnov --- .../onnxruntime/core/framework/data_types.h | 18 +- include/onnxruntime/core/framework/ml_value.h | 5 - .../onnxruntime/core/framework/op_kernel.h | 26 +- .../core/framework/sparse_tensor.h | 4 - include/onnxruntime/core/graph/basic_types.h | 2 + include/onnxruntime/core/graph/graph.h | 7 +- .../core/flatbuffers/flatbuffers_utils.cc | 25 +- onnxruntime/core/flatbuffers/schema/README.md | 9 + onnxruntime/core/flatbuffers/schema/ort.fbs | 11 +- onnxruntime/core/flatbuffers/schema/ort.fbs.h | 129 +++++++++- onnxruntime/core/framework/data_types.cc | 27 -- onnxruntime/core/framework/execution_frame.cc | 7 - .../core/framework/onnxruntime_typeinfo.cc | 29 +-- onnxruntime/core/framework/op_kernel.cc | 2 - onnxruntime/core/framework/sparse_tensor.cc | 4 - .../core/framework/tensor_type_and_shape.cc | 33 +-- .../core/framework/tensorprotoutils.cc | 239 ++++++++++++------ onnxruntime/core/framework/tensorprotoutils.h | 12 +- onnxruntime/core/graph/graph.cc | 136 ++++++++-- .../core/graph/graph_flatbuffers_utils.cc | 67 ++++- .../core/graph/graph_flatbuffers_utils.h | 8 + onnxruntime/core/graph/model.cc | 11 +- onnxruntime/core/graph/model.h | 2 +- .../cpu/element_wise_ranged_transform.h | 2 +- .../DmlExecutionProvider/src/ErrorHandling.h | 2 +- onnxruntime/core/session/inference_session.cc | 24 +- .../test/framework/ort_model_only_test.cc | 32 +++ .../test/framework/sparse_kernels_test.cc | 222 +++++++++++----- onnxruntime/test/ir/graph_test.cc | 149 +++++++++++ .../testdata/sparse_initializer_handling.onnx | Bin 0 -> 324 bytes .../sparse_initializer_handling.onnx.ort | Bin 0 -> 1488 bytes orttraining/orttraining/models/gpt2/main.cc | 4 +- tools/ci_build/build.py | 4 +- 33 files changed, 925 insertions(+), 327 deletions(-) create mode 100644 onnxruntime/test/testdata/sparse_initializer_handling.onnx create mode 100644 onnxruntime/test/testdata/sparse_initializer_handling.onnx.ort diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index e136f04147..5f539cfc43 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -46,12 +46,10 @@ using VectorInt64 = std::vector; class DataTypeImpl; class TensorTypeBase; +class SparseTensorTypeBase; class SequenceTensorTypeBase; class NonTensorTypeBase; class PrimitiveDataTypeBase; -#if !defined(ORT_MINIMAL_BUILD) -class SparseTensorTypeBase; -#endif // MLFloat16 union MLFloat16 { @@ -235,12 +233,10 @@ class DataTypeImpl { return nullptr; } -#if !defined(ORT_MINIMAL_BUILD) // Returns this if this is of sparse-tensor-type and null otherwise virtual const SparseTensorTypeBase* AsSparseTensorType() const { return nullptr; } -#endif virtual const NonTensorTypeBase* AsNonTensorTypeBase() const { return nullptr; @@ -263,11 +259,9 @@ class DataTypeImpl { template static MLDataType GetSequenceTensorType(); -#if !defined(ORT_MINIMAL_BUILD) // Return the MLDataType for a concrete sparse tensor type. template static MLDataType GetSparseTensorType(); -#endif /** * Convert an ONNX TypeProto to onnxruntime DataTypeImpl. @@ -279,10 +273,8 @@ class DataTypeImpl { static MLDataType TypeFromProto(const ONNX_NAMESPACE::TypeProto& proto); static const TensorTypeBase* TensorTypeFromONNXEnum(int type); - static const NonTensorTypeBase* SequenceTensorTypeFromONNXEnum(int type); -#if !defined(ORT_MINIMAL_BUILD) static const SparseTensorTypeBase* SparseTensorTypeFromONNXEnum(int type); -#endif + static const NonTensorTypeBase* SequenceTensorTypeFromONNXEnum(int type); static const char* ToString(MLDataType type); // Registers ONNX_NAMESPACE::DataType (internalized string) with @@ -405,7 +397,6 @@ struct IsTensorContainedType : public IsAnyOf { }; -#if !defined(ORT_MINIMAL_BUILD) /// Use "IsSparseTensorContainedType::value" to test if a type T /// is permitted as the element-type of a sparse-tensor. @@ -414,7 +405,6 @@ struct IsSparseTensorContainedType : public IsAnyOf { }; -#endif /// This template's Get() returns a corresponding MLDataType /// It dispatches the call to either GetTensorType<>() or @@ -566,7 +556,6 @@ class TensorType : public TensorTypeBase { } }; -#if !defined(ORT_MINIMAL_BUILD) /// Common base-class for all sparse-tensors (with different element types). class SparseTensorTypeBase : public DataTypeImpl { public: @@ -626,7 +615,6 @@ class SparseTensorType : public SparseTensorTypeBase { TensorElementTypeSetter::SetSparseTensorElementType(mutable_type_proto()); } }; -#endif // !defined(ORT_MINIMAL_BUILD) /** * \brief Provide a specialization for your C++ Non-tensor type @@ -977,7 +965,6 @@ class PrimitiveDataType : public PrimitiveDataTypeBase { return TensorType::Type(); \ } -#if !defined(ORT_MINIMAL_BUILD) #define ORT_REGISTER_SPARSE_TENSOR_TYPE(ELEM_TYPE) \ template <> \ MLDataType SparseTensorType::Type() { \ @@ -988,7 +975,6 @@ class PrimitiveDataType : public PrimitiveDataTypeBase { MLDataType DataTypeImpl::GetSparseTensorType() { \ return SparseTensorType::Type(); \ } -#endif #if !defined(DISABLE_ML_OPS) #define ORT_REGISTER_MAP(TYPE) \ diff --git a/include/onnxruntime/core/framework/ml_value.h b/include/onnxruntime/core/framework/ml_value.h index a304e10447..27ffa58875 100644 --- a/include/onnxruntime/core/framework/ml_value.h +++ b/include/onnxruntime/core/framework/ml_value.h @@ -11,10 +11,7 @@ #include "core/framework/tensor.h" namespace onnxruntime { -#if !defined(ORT_MINIMAL_BUILD) class SparseTensor; -#endif - class TensorSeq; } // namespace onnxruntime @@ -109,7 +106,6 @@ inline onnxruntime::TensorSeq* OrtValue::GetMutable() { return static_cast(data_.get()); } -#if !defined(ORT_MINIMAL_BUILD) template <> inline const onnxruntime::SparseTensor& OrtValue::Get() const { ORT_ENFORCE(IsSparseTensor(), "Trying to get a SparseTensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_)); @@ -121,7 +117,6 @@ inline onnxruntime::SparseTensor* OrtValue::GetMutable(data_.get()); } -#endif //TODO: remove the following line #define MLValue OrtValue diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 6daa237cc3..32c00d9866 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -143,14 +143,12 @@ class OpKernelContext { return *output_ptr; } -#if !defined(ORT_MINIMAL_BUILD) // Fetch a sparse-tensor output corresponding to the specified index. // num_values must specify the number of non-zero values (commonly known as NNZ/nnz), // and shape must specify the shape of the underlying dense-tensor. // Memory allocation for the output may happen when this method is invoked, // unless static optimization pre-allocates it. SparseTensor* Output(int index, size_t num_values, const TensorShape& shape); -#endif // Retrieve indexed shape obtained from memory planning before actual // computation. If the indexed shape cannot be inferred, this function returns @@ -437,18 +435,18 @@ using BuildKernelCreateInfoFn = KernelCreateInfo (*)(); #define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name) \ provider##_##name##_##domain##_ver##startver##_##endver##_##type1##_##type2 -#define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX(name, domain, startver, endver, type1, type2, provider, builder, ...) \ - class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name); \ - template <> \ - KernelCreateInfo \ - BuildKernelCreateInfo() { \ - return KernelCreateInfo( \ - builder.SetName(#name) \ - .SetDomain(domain) \ - .SinceVersion(startver, endver) \ - .Provider(provider) \ - .Build(), \ - static_cast([](const OpKernelInfo& info) -> OpKernel* { return new __VA_ARGS__(info); })); \ +#define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX(name, domain, startver, endver, type1, type2, provider, builder, ...) \ + class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name) \ + .SetDomain(domain) \ + .SinceVersion(startver, endver) \ + .Provider(provider) \ + .Build(), \ + static_cast([](const OpKernelInfo& info) -> OpKernel* { return new __VA_ARGS__(info); })); \ } // Use within macro definitions to create a custom vector of constraints. diff --git a/include/onnxruntime/core/framework/sparse_tensor.h b/include/onnxruntime/core/framework/sparse_tensor.h index 05afe94a5f..ec0abb91ec 100644 --- a/include/onnxruntime/core/framework/sparse_tensor.h +++ b/include/onnxruntime/core/framework/sparse_tensor.h @@ -3,8 +3,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if !defined(ORT_MINIMAL_BUILD) - #include "core/framework/data_types.h" #include "core/framework/tensor_shape.h" #include "core/framework/tensor.h" @@ -72,5 +70,3 @@ class SparseTensor final { }; } // namespace onnxruntime - -#endif diff --git a/include/onnxruntime/core/graph/basic_types.h b/include/onnxruntime/core/graph/basic_types.h index a0085a5ad9..7c1203fc1f 100644 --- a/include/onnxruntime/core/graph/basic_types.h +++ b/include/onnxruntime/core/graph/basic_types.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include #include #include @@ -10,6 +11,7 @@ namespace ONNX_NAMESPACE { class ValueInfoProto; class TensorProto; +class SparseTensorProto; class TypeProto; class AttributeProto; // define types that would come from the ONNX library if we were building against it. diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index e726bc0d7c..0c153350cf 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -320,10 +320,7 @@ class Node { ADD_ATTR_INTERFACES(std::string) ADD_ATTR_INTERFACES(ONNX_NAMESPACE::TensorProto) ADD_ATTR_INTERFACES(ONNX_NAMESPACE::GraphProto) - -#if !defined(ORT_MINIMAL_BUILD) ADD_ATTR_INTERFACES(ONNX_NAMESPACE::SparseTensorProto) -#endif /** Gets the Node's attributes. */ const NodeAttributes& GetAttributes() const noexcept { return attributes_; } @@ -1265,6 +1262,10 @@ class Graph { InitializedTensorSet name_to_initial_tensor_; + std::unordered_set, + std::hash, std::equal_to> + sparse_tensor_names_; + #if !defined(ORT_MINIMAL_BUILD) IOnnxRuntimeOpSchemaCollectionPtr schema_registry_; diff --git a/onnxruntime/core/flatbuffers/flatbuffers_utils.cc b/onnxruntime/core/flatbuffers/flatbuffers_utils.cc index 61a9524f9e..9c8995e23f 100644 --- a/onnxruntime/core/flatbuffers/flatbuffers_utils.cc +++ b/onnxruntime/core/flatbuffers/flatbuffers_utils.cc @@ -197,21 +197,26 @@ static Status LoadTensorDimensionOrtFormat(const fbs::Dimension& fbs_dim, return Status::OK(); } +static Status LoadTensorShapeOrtFormat(const fbs::Shape& fbs_shape, TensorShapeProto& shape_proto) { + auto fbs_dims = fbs_shape.dim(); + if (fbs_dims) { + auto dims = shape_proto.mutable_dim(); + dims->Reserve(fbs_dims->size()); + for (const auto fbs_dim : *fbs_dims) { + ORT_RETURN_IF(nullptr == fbs_dim, "Null entry in dimensions. Invalid ORT format model."); + TensorShapeProto_Dimension dim; + ORT_RETURN_IF_ERROR(LoadTensorDimensionOrtFormat(*fbs_dim, *dims->Add())); + } + } + return Status::OK(); +} + static Status LoadTensorTypeAndShapeOrtFormat(const fbs::TensorTypeAndShape& fbs_tensor_type, TypeProto_Tensor& tensor_type_proto) { tensor_type_proto.set_elem_type(static_cast(fbs_tensor_type.elem_type())); auto fbs_shape = fbs_tensor_type.shape(); if (fbs_shape) { - auto fbs_dims = fbs_shape->dim(); - if (fbs_dims) { - auto dims = tensor_type_proto.mutable_shape()->mutable_dim(); - dims->Reserve(fbs_dims->size()); - for (const auto fbs_dim : *fbs_dims) { - ORT_RETURN_IF(nullptr == fbs_dim, "Null entry in dimensions. Invalid ORT format model."); - TensorShapeProto_Dimension dim; - ORT_RETURN_IF_ERROR(LoadTensorDimensionOrtFormat(*fbs_dim, *dims->Add())); - } - } + ORT_RETURN_IF_ERROR(LoadTensorShapeOrtFormat(*fbs_shape, *tensor_type_proto.mutable_shape())); } return Status::OK(); } diff --git a/onnxruntime/core/flatbuffers/schema/README.md b/onnxruntime/core/flatbuffers/schema/README.md index f3938d8388..6815a1e19c 100644 --- a/onnxruntime/core/flatbuffers/schema/README.md +++ b/onnxruntime/core/flatbuffers/schema/README.md @@ -16,3 +16,12 @@ Change to the directory containing this file (onnxruntime/core/flatbuffers) and `> ..\..\..\build\Windows\Debug\external\flatbuffers\Debug\flatc.exe --cpp --scoped-enums --filename-suffix .fbs ort.fbs` This should result in ort.fbs.h being updated. + +# ORT FB format version history +`See onnxruntime/core/session/inference_session.cc:IsOrtModelVersionSupported()` for version array and `kOrtModelVersion` for currently supported version. + +## Version 1. History begins +Initial support for FlatBuffers that includes Model support. Graph support including Attributes, Tensors, Tensor Sequences, Maps and Sequences. Constant initializers are also supported. Constant nodes are converted to constant initializers in the ORT format. + +## Version 2. +Support for sparse initialiers. Sparse intializers are stored within ORT FlatBuffers format, which includes sparse initializers converted from Constant node attribute. \ No newline at end of file diff --git a/onnxruntime/core/flatbuffers/schema/ort.fbs b/onnxruntime/core/flatbuffers/schema/ort.fbs index 6e7a41acfe..cb111247d9 100644 --- a/onnxruntime/core/flatbuffers/schema/ort.fbs +++ b/onnxruntime/core/flatbuffers/schema/ort.fbs @@ -121,7 +121,7 @@ table ValueInfo { type:TypeInfo; } -// TODO add support of SparseTensor/Opaque +// TODO add support of SparseTensor, Opaque if needed union TypeInfoValue { tensor_type:TensorTypeAndShape, sequence_type:SequenceType, @@ -155,6 +155,12 @@ table Tensor { string_data:[string]; } +table SparseTensor { + values:Tensor; + indices:Tensor; + dims:[int64]; +} + table Attribute{ name:string; doc_string:string; @@ -185,6 +191,7 @@ table Graph{ inputs:[string]; outputs:[string]; + sparse_initializers:[SparseTensor]; } table Model { @@ -219,7 +226,7 @@ table SessionState { table InferenceSession { // This is the ORT format model version // The version number is defined as kOrtModelVersion in /onnxruntime/core/session/inference_session.cc - // Please update it when there is a change to this schema which will break the compatibilites + // Please update it when there is a change to this schema which will break the compatibilities ort_version:string; model:Model; diff --git a/onnxruntime/core/flatbuffers/schema/ort.fbs.h b/onnxruntime/core/flatbuffers/schema/ort.fbs.h index 1516f49cc5..3da4d060c9 100644 --- a/onnxruntime/core/flatbuffers/schema/ort.fbs.h +++ b/onnxruntime/core/flatbuffers/schema/ort.fbs.h @@ -48,6 +48,9 @@ struct OperatorSetIdBuilder; struct Tensor; struct TensorBuilder; +struct SparseTensor; +struct SparseTensorBuilder; + struct Attribute; struct AttributeBuilder; @@ -332,10 +335,8 @@ FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) EdgeEnd FLATBUFFERS_FINAL_CLASS { int32_t dst_arg_index_; public: - EdgeEnd() - : node_index_(0), - src_arg_index_(0), - dst_arg_index_(0) { + EdgeEnd() { + memset(static_cast(this), 0, sizeof(EdgeEnd)); } EdgeEnd(uint32_t _node_index, int32_t _src_arg_index, int32_t _dst_arg_index) : node_index_(flatbuffers::EndianScalar(_node_index)), @@ -382,6 +383,7 @@ struct ShapeBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + ShapeBuilder &operator=(const ShapeBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -442,6 +444,7 @@ struct DimensionBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + DimensionBuilder &operator=(const DimensionBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -513,6 +516,7 @@ struct DimensionValueBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + DimensionValueBuilder &operator=(const DimensionValueBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -580,6 +584,7 @@ struct TensorTypeAndShapeBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + TensorTypeAndShapeBuilder &operator=(const TensorTypeAndShapeBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -632,6 +637,7 @@ struct MapTypeBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + MapTypeBuilder &operator=(const MapTypeBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -676,6 +682,7 @@ struct SequenceTypeBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + SequenceTypeBuilder &operator=(const SequenceTypeBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -735,6 +742,7 @@ struct NodeEdgeBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + NodeEdgeBuilder &operator=(const NodeEdgeBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -904,6 +912,7 @@ struct NodeBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + NodeBuilder &operator=(const NodeBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -1030,6 +1039,7 @@ struct ValueInfoBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + ValueInfoBuilder &operator=(const ValueInfoBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -1129,6 +1139,7 @@ struct TypeInfoBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + TypeInfoBuilder &operator=(const TypeInfoBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -1196,6 +1207,7 @@ struct OperatorSetIdBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + OperatorSetIdBuilder &operator=(const OperatorSetIdBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -1296,6 +1308,7 @@ struct TensorBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + TensorBuilder &operator=(const TensorBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -1344,6 +1357,84 @@ inline flatbuffers::Offset CreateTensorDirect( string_data__); } +struct SparseTensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SparseTensorBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VALUES = 4, + VT_INDICES = 6, + VT_DIMS = 8 + }; + const onnxruntime::experimental::fbs::Tensor *values() const { + return GetPointer(VT_VALUES); + } + const onnxruntime::experimental::fbs::Tensor *indices() const { + return GetPointer(VT_INDICES); + } + const flatbuffers::Vector *dims() const { + return GetPointer *>(VT_DIMS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_VALUES) && + verifier.VerifyTable(values()) && + VerifyOffset(verifier, VT_INDICES) && + verifier.VerifyTable(indices()) && + VerifyOffset(verifier, VT_DIMS) && + verifier.VerifyVector(dims()) && + verifier.EndTable(); + } +}; + +struct SparseTensorBuilder { + typedef SparseTensor Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_values(flatbuffers::Offset values) { + fbb_.AddOffset(SparseTensor::VT_VALUES, values); + } + void add_indices(flatbuffers::Offset indices) { + fbb_.AddOffset(SparseTensor::VT_INDICES, indices); + } + void add_dims(flatbuffers::Offset> dims) { + fbb_.AddOffset(SparseTensor::VT_DIMS, dims); + } + explicit SparseTensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SparseTensorBuilder &operator=(const SparseTensorBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSparseTensor( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset values = 0, + flatbuffers::Offset indices = 0, + flatbuffers::Offset> dims = 0) { + SparseTensorBuilder builder_(_fbb); + builder_.add_dims(dims); + builder_.add_indices(indices); + builder_.add_values(values); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateSparseTensorDirect( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset values = 0, + flatbuffers::Offset indices = 0, + const std::vector *dims = nullptr) { + auto dims__ = dims ? _fbb.CreateVector(*dims) : 0; + return onnxruntime::experimental::fbs::CreateSparseTensor( + _fbb, + values, + indices, + dims__); +} + struct Attribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef AttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { @@ -1479,6 +1570,7 @@ struct AttributeBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + AttributeBuilder &operator=(const AttributeBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -1567,7 +1659,8 @@ struct Graph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_MAX_NODE_INDEX = 10, VT_NODE_EDGES = 12, VT_INPUTS = 14, - VT_OUTPUTS = 16 + VT_OUTPUTS = 16, + VT_SPARSE_INITIALIZERS = 18 }; const flatbuffers::Vector> *initializers() const { return GetPointer> *>(VT_INITIALIZERS); @@ -1590,6 +1683,9 @@ struct Graph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector> *outputs() const { return GetPointer> *>(VT_OUTPUTS); } + const flatbuffers::Vector> *sparse_initializers() const { + return GetPointer> *>(VT_SPARSE_INITIALIZERS); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_INITIALIZERS) && @@ -1611,6 +1707,9 @@ struct Graph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_OUTPUTS) && verifier.VerifyVector(outputs()) && verifier.VerifyVectorOfStrings(outputs()) && + VerifyOffset(verifier, VT_SPARSE_INITIALIZERS) && + verifier.VerifyVector(sparse_initializers()) && + verifier.VerifyVectorOfTables(sparse_initializers()) && verifier.EndTable(); } }; @@ -1640,10 +1739,14 @@ struct GraphBuilder { void add_outputs(flatbuffers::Offset>> outputs) { fbb_.AddOffset(Graph::VT_OUTPUTS, outputs); } + void add_sparse_initializers(flatbuffers::Offset>> sparse_initializers) { + fbb_.AddOffset(Graph::VT_SPARSE_INITIALIZERS, sparse_initializers); + } explicit GraphBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } + GraphBuilder &operator=(const GraphBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -1659,8 +1762,10 @@ inline flatbuffers::Offset CreateGraph( uint32_t max_node_index = 0, flatbuffers::Offset>> node_edges = 0, flatbuffers::Offset>> inputs = 0, - flatbuffers::Offset>> outputs = 0) { + flatbuffers::Offset>> outputs = 0, + flatbuffers::Offset>> sparse_initializers = 0) { GraphBuilder builder_(_fbb); + builder_.add_sparse_initializers(sparse_initializers); builder_.add_outputs(outputs); builder_.add_inputs(inputs); builder_.add_node_edges(node_edges); @@ -1679,13 +1784,15 @@ inline flatbuffers::Offset CreateGraphDirect( uint32_t max_node_index = 0, const std::vector> *node_edges = nullptr, const std::vector> *inputs = nullptr, - const std::vector> *outputs = nullptr) { + const std::vector> *outputs = nullptr, + const std::vector> *sparse_initializers = nullptr) { auto initializers__ = initializers ? _fbb.CreateVector>(*initializers) : 0; auto node_args__ = node_args ? _fbb.CreateVector>(*node_args) : 0; auto nodes__ = nodes ? _fbb.CreateVector>(*nodes) : 0; auto node_edges__ = node_edges ? _fbb.CreateVector>(*node_edges) : 0; auto inputs__ = inputs ? _fbb.CreateVector>(*inputs) : 0; auto outputs__ = outputs ? _fbb.CreateVector>(*outputs) : 0; + auto sparse_initializers__ = sparse_initializers ? _fbb.CreateVector>(*sparse_initializers) : 0; return onnxruntime::experimental::fbs::CreateGraph( _fbb, initializers__, @@ -1694,7 +1801,8 @@ inline flatbuffers::Offset CreateGraphDirect( max_node_index, node_edges__, inputs__, - outputs__); + outputs__, + sparse_initializers__); } struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -1786,6 +1894,7 @@ struct ModelBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + ModelBuilder &operator=(const ModelBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -1878,6 +1987,7 @@ struct KernelCreateInfosBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + KernelCreateInfosBuilder &operator=(const KernelCreateInfosBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -1949,6 +2059,7 @@ struct SubGraphSessionStateBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + SubGraphSessionStateBuilder &operator=(const SubGraphSessionStateBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -2015,6 +2126,7 @@ struct SessionStateBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + SessionStateBuilder &operator=(const SessionStateBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -2088,6 +2200,7 @@ struct InferenceSessionBuilder { : fbb_(_fbb) { start_ = fbb_.StartTable(); } + InferenceSessionBuilder &operator=(const InferenceSessionBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); diff --git a/onnxruntime/core/framework/data_types.cc b/onnxruntime/core/framework/data_types.cc index 5d99b07bf2..fe445c0de7 100644 --- a/onnxruntime/core/framework/data_types.cc +++ b/onnxruntime/core/framework/data_types.cc @@ -34,13 +34,11 @@ MLDataType DataTypeImpl::GetType() { namespace onnxruntime { -#if !defined(ORT_MINIMAL_BUILD) // Return the MLDataType used for a generic SparseTensor template <> MLDataType DataTypeImpl::GetType() { return SparseTensorTypeBase::Type(); } -#endif template <> MLDataType DataTypeImpl::GetType() { @@ -59,12 +57,9 @@ struct TensorElementTypeSetter { static void SetTensorElementType(ONNX_NAMESPACE::TypeProto& proto) { proto.mutable_tensor_type()->set_elem_type(utils::ToTensorProtoElementType()); } - -#if !defined(ORT_MINIMAL_BUILD) static void SetSparseTensorElementType(ONNX_NAMESPACE::TypeProto& proto) { proto.mutable_sparse_tensor_type()->set_elem_type(utils::ToTensorProtoElementType()); } -#endif #if !defined(DISABLE_ML_OPS) static void SetMapKeyType(ONNX_NAMESPACE::TypeProto& proto) { @@ -129,10 +124,8 @@ void AssignOpaqueDomainName(const char* domain, const char* name, bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Tensor& tensor_proto, const ONNX_NAMESPACE::TypeProto_Tensor& type_proto); -#if !defined(ORT_MINIMAL_BUILD) bool IsCompatible(const ONNX_NAMESPACE::TypeProto_SparseTensor& tensor_proto, const ONNX_NAMESPACE::TypeProto_SparseTensor& type_proto); -#endif #if !defined(DISABLE_ML_OPS) bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Map& map_proto, @@ -175,11 +168,9 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Map& map_proto, case TypeProto::ValueCase::kOpaqueType: result = IsCompatible(lhs.value_type().opaque_type(), rhs.value_type().opaque_type()); break; -#if !defined(ORT_MINIMAL_BUILD) case TypeProto::ValueCase::kSparseTensorType: result = IsCompatible(lhs.value_type().sparse_tensor_type(), rhs.value_type().sparse_tensor_type()); break; -#endif default: ORT_ENFORCE(false); break; @@ -212,11 +203,9 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Sequence& sequence_proto, case TypeProto::ValueCase::kOpaqueType: result = IsCompatible(lhs.elem_type().opaque_type(), rhs.elem_type().opaque_type()); break; -#if !defined(ORT_MINIMAL_BUILD) case TypeProto::ValueCase::kSparseTensorType: result = IsCompatible(lhs.elem_type().sparse_tensor_type(), rhs.elem_type().sparse_tensor_type()); break; -#endif default: ORT_ENFORCE(false); break; @@ -226,7 +215,6 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Sequence& sequence_proto, } return result; } - bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Opaque& opaque_proto, const ONNX_NAMESPACE::TypeProto_Opaque& type_proto) { const auto& lhs = opaque_proto; @@ -246,12 +234,10 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Opaque& opaque_proto, (lhs_name && rhs_name && lhs.name() != rhs.name())); } -#if !defined(ORT_MINIMAL_BUILD) bool IsCompatible(const ONNX_NAMESPACE::TypeProto_SparseTensor& tensor_proto, const ONNX_NAMESPACE::TypeProto_SparseTensor& type_proto) { return type_proto.elem_type() == tensor_proto.elem_type(); } -#endif void RegisterAllProtos(const std::function& /*reg_fn*/); @@ -370,8 +356,6 @@ MLDataType TensorTypeBase::Type() { return &tensor_base; } -#if !defined(ORT_MINIMAL_BUILD) - /// SparseTensor struct SparseTensorTypeBase::Impl : public data_types_internal::TypeProtoImpl { @@ -417,7 +401,6 @@ MLDataType SparseTensorTypeBase::Type() { static SparseTensorTypeBase sparse_tensor_base; return &sparse_tensor_base; } -#endif // !defined(ORT_MINIMAL_BUILD) ///// SequenceTensorTypeBase @@ -549,7 +532,6 @@ ORT_REGISTER_TENSOR_TYPE(uint64_t); ORT_REGISTER_TENSOR_TYPE(MLFloat16); ORT_REGISTER_TENSOR_TYPE(BFloat16); -#if !defined(ORT_MINIMAL_BUILD) ORT_REGISTER_SPARSE_TENSOR_TYPE(int32_t); ORT_REGISTER_SPARSE_TENSOR_TYPE(float); ORT_REGISTER_SPARSE_TENSOR_TYPE(bool); @@ -564,7 +546,6 @@ ORT_REGISTER_SPARSE_TENSOR_TYPE(uint32_t); ORT_REGISTER_SPARSE_TENSOR_TYPE(uint64_t); ORT_REGISTER_SPARSE_TENSOR_TYPE(MLFloat16); ORT_REGISTER_SPARSE_TENSOR_TYPE(BFloat16); -#endif #if !defined(DISABLE_ML_OPS) ORT_REGISTER_MAP(MapStringToString); @@ -610,13 +591,11 @@ ORT_REGISTER_SEQ(VectorMapInt64ToFloat); reg_fn(mltype); \ } -#if !defined(ORT_MINIMAL_BUILD) #define REGISTER_SPARSE_TENSOR_PROTO(TYPE, reg_fn) \ { \ MLDataType mltype = DataTypeImpl::GetSparseTensorType(); \ reg_fn(mltype); \ } -#endif #define REGISTER_ONNX_PROTO(TYPE, reg_fn) \ { \ @@ -642,7 +621,6 @@ void RegisterAllProtos(const std::function& reg_fn) { REGISTER_TENSOR_PROTO(MLFloat16, reg_fn); REGISTER_TENSOR_PROTO(BFloat16, reg_fn); -#if !defined(ORT_MINIMAL_BUILD) REGISTER_SPARSE_TENSOR_PROTO(int32_t, reg_fn); REGISTER_SPARSE_TENSOR_PROTO(float, reg_fn); REGISTER_SPARSE_TENSOR_PROTO(bool, reg_fn); @@ -657,7 +635,6 @@ void RegisterAllProtos(const std::function& reg_fn) { REGISTER_SPARSE_TENSOR_PROTO(uint64_t, reg_fn); REGISTER_SPARSE_TENSOR_PROTO(MLFloat16, reg_fn); REGISTER_SPARSE_TENSOR_PROTO(BFloat16, reg_fn); -#endif #if !defined(DISABLE_ML_OPS) REGISTER_ONNX_PROTO(MapStringToString, reg_fn); @@ -820,7 +797,6 @@ const NonTensorTypeBase* DataTypeImpl::SequenceTensorTypeFromONNXEnum(int type) } } -#if !defined(ORT_MINIMAL_BUILD) const SparseTensorTypeBase* DataTypeImpl::SparseTensorTypeFromONNXEnum(int type) { switch (type) { case TensorProto_DataType_FLOAT: @@ -855,7 +831,6 @@ const SparseTensorTypeBase* DataTypeImpl::SparseTensorTypeFromONNXEnum(int type) ORT_NOT_IMPLEMENTED("sparse tensor type ", type, " is not supported"); } } -#endif MLDataType DataTypeImpl::TypeFromProto(const ONNX_NAMESPACE::TypeProto& proto) { const auto& registry = data_types_internal::DataTypeRegistry::instance(); @@ -1030,7 +1005,6 @@ ContainerChecker::ContainerChecker(MLDataType ml_type) { types_.emplace_back(ContainerType::kTensor, type_proto->tensor_type().elem_type()); type_proto = nullptr; break; - #if !defined(DISABLE_ML_OPS) case TypeProto::ValueCase::kMapType: { const auto& map_type = type_proto->map_type(); @@ -1043,7 +1017,6 @@ ContainerChecker::ContainerChecker(MLDataType ml_type) { types_.emplace_back(ContainerType::kSequence, TensorProto_DataType_UNDEFINED); type_proto = &type_proto->sequence_type().elem_type(); break; - case TypeProto::ValueCase::kOpaqueType: // We do not handle this and terminate here types_.emplace_back(ContainerType::kOpaque, diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 694e137ab7..f9184c9eaa 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -478,7 +478,6 @@ static Status AllocateTensorSequence(OrtValue& ort_value) { return Status::OK(); } -#if !defined(ORT_MINIMAL_BUILD) static Status AllocateSparseTensor(MLValue& mlvalue, const DataTypeImpl& ml_type, AllocatorPtr allocator, const TensorShape& shape, size_t nnz, bool create_fence, const SessionState& session_state) { @@ -496,7 +495,6 @@ static Status AllocateSparseTensor(MLValue& mlvalue, const DataTypeImpl& ml_type return Status::OK(); } -#endif // This method is not thread safe! Status ExecutionFrame::AllocateAsPerAllocationPlan(OrtValue& ort_value, int ort_value_index, const TensorShape* shape, @@ -570,13 +568,8 @@ Status ExecutionFrame::AllocateAsPerAllocationPlan(OrtValue& ort_value, int ort_ return Status::OK(); } else if (ml_type->IsSparseTensorType()) { -#if !defined(ORT_MINIMAL_BUILD) return AllocateSparseTensor(ort_value, *ml_type, GetAllocator(alloc_info), *shape, nnz, per_alloc_plan.create_fence_if_async, session_state_); -#else - // Model load should have failed so this should be unreachable - ORT_THROW("SparseTensor is not supported in this build."); -#endif } else if (ml_type->IsTensorSequenceType()) { return AllocateTensorSequence(ort_value); } else { diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index 2d5e128271..145f6fbab5 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -20,9 +20,7 @@ using onnxruntime::BFloat16; using onnxruntime::DataTypeImpl; using onnxruntime::MLFloat16; -#if !defined(ORT_MINIMAL_BUILD) using onnxruntime::SparseTensor; -#endif using onnxruntime::Tensor; using onnxruntime::TensorShape; @@ -121,7 +119,6 @@ OrtStatus* OrtTypeInfo::FromOrtValue(const OrtValue& value, OrtTypeInfo** out) { } if (type->IsSparseTensorType()) { -#if !defined(ORT_MINIMAL_BUILD) OrtTensorTypeAndShapeInfo* info = nullptr; const SparseTensor& tensor = value.Get(); const auto* tensor_data_type = tensor.Values().DataType(); @@ -131,9 +128,6 @@ OrtStatus* OrtTypeInfo::FromOrtValue(const OrtValue& value, OrtTypeInfo** out) { } *out = new OrtTypeInfo(ONNX_TYPE_SPARSETENSOR, info); return nullptr; -#else - return OrtApis::CreateStatus(ORT_FAIL, "SparseTensor is not supported in this build."); -#endif } if (type->IsTensorSequenceType()) { @@ -163,11 +157,9 @@ OrtStatus* OrtTypeInfo::FromOrtValue(const OrtValue& value, OrtTypeInfo** out) { *out = new OrtTypeInfo(ONNX_TYPE_OPAQUE); return nullptr; } -#if !defined(DISABLE_ML_OPS) case on::TypeProto::kMapType: { return OrtTypeInfo::FromTypeProto(type_proto, out); } -#endif case on::TypeProto::kSequenceType: { return OrtTypeInfo::FromTypeProto(type_proto, out); } @@ -228,6 +220,7 @@ OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, Or case on::TypeProto::kSparseTensorType: { ONNXType ten_type = ONNX_TYPE_UNKNOWN; const on::TypeProto_Tensor* tensor_type = nullptr; + const on::TypeProto_SparseTensor* sparse_type = nullptr; const on::TensorShapeProto* sp = nullptr; if (value_case == on::TypeProto::kTensorType) { tensor_type = &input->tensor_type(); @@ -236,15 +229,11 @@ OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, Or sp = &tensor_type->shape(); } } else if (value_case == on::TypeProto::kSparseTensorType) { -#if !defined(ORT_MINIMAL_BUILD) - const on::TypeProto_SparseTensor* sparse_type = &input->sparse_tensor_type(); + sparse_type = &input->sparse_tensor_type(); ten_type = ONNX_TYPE_SPARSETENSOR; if (onnxruntime::utils::HasShape(*sparse_type)) { sp = &sparse_type->shape(); } -#else - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Sparse tensors are not supported in this build."); -#endif } OrtStatus* st = nullptr; @@ -279,7 +268,7 @@ OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, Or type_info->denotation = input->denotation(); *out = type_info; return nullptr; - } + } break; case on::TypeProto::kSequenceType: { OrtSequenceTypeInfo* sequence_type_info = nullptr; @@ -291,9 +280,8 @@ OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, Or type_info->denotation = input->denotation(); *out = type_info; return nullptr; - } + } break; case on::TypeProto::kMapType: { -#if !defined(DISABLE_ML_OPS) OrtMapTypeInfo* map_type_info = nullptr; if (auto status = OrtMapTypeInfo::FromTypeProto(input, &map_type_info)) { @@ -304,16 +292,13 @@ OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, Or type_info->denotation = input->denotation(); *out = type_info; return nullptr; -#else - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Map data types are not supported in this build."); -#endif - } + } break; case on::TypeProto::kOpaqueType: { auto type_info = new OrtTypeInfo(ONNX_TYPE_OPAQUE); type_info->denotation = input->denotation(); *out = type_info; return nullptr; - } + } break; case on::TypeProto::VALUE_NOT_SET: break; default: @@ -363,4 +348,4 @@ OrtStatus* OrtTypeInfo::Clone(OrtTypeInfo** out) { break; } return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented"); -} +} \ No newline at end of file diff --git a/onnxruntime/core/framework/op_kernel.cc b/onnxruntime/core/framework/op_kernel.cc index 02835139e3..b3c05bb4e5 100644 --- a/onnxruntime/core/framework/op_kernel.cc +++ b/onnxruntime/core/framework/op_kernel.cc @@ -33,12 +33,10 @@ Tensor* OpKernelContext::Output(int index, const std::initializer_list& return Output(index, TensorShape(shape)); } -#if !defined(ORT_MINIMAL_BUILD) SparseTensor* OpKernelContext::Output(int index, size_t nnz, const TensorShape& shape) { auto p_ml_value = OutputMLValue(index, shape, nnz); return p_ml_value ? p_ml_value->GetMutable() : nullptr; } -#endif bool OpKernelContext::TryGetInferredInputShape(int index, TensorShape& shape) const { return execution_frame_->TryGetInferredShape(GetInputArgIndex(index), shape); diff --git a/onnxruntime/core/framework/sparse_tensor.cc b/onnxruntime/core/framework/sparse_tensor.cc index 0d07dc1799..6ea2f026a3 100644 --- a/onnxruntime/core/framework/sparse_tensor.cc +++ b/onnxruntime/core/framework/sparse_tensor.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if !defined(ORT_MINIMAL_BUILD) - #include "core/framework/data_types.h" #include "core/framework/sparse_tensor.h" @@ -33,5 +31,3 @@ SparseTensor::SparseTensor(MLDataType elt_type, shape_(shape) {} } // namespace onnxruntime - -#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index c780efcecf..c20945a767 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -18,10 +18,8 @@ using onnxruntime::BFloat16; using onnxruntime::DataTypeImpl; using onnxruntime::MLFloat16; -using onnxruntime::Tensor; -#if !defined(ORT_MINIMAL_BUILD) using onnxruntime::SparseTensor; -#endif +using onnxruntime::Tensor; ORT_API_STATUS_IMPL(OrtApis::CreateTensorTypeAndShapeInfo, _Outptr_ OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN @@ -203,25 +201,22 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out API_IMPL_BEGIN onnxruntime::MLDataType type = v->Type(); ORT_ENFORCE(type != nullptr, "OrtValue is not a Tensor"); - const onnxruntime::TensorShape* shape = nullptr; - onnxruntime::MLDataType data_type = nullptr; - if (type->IsTensorType()) { - const Tensor& tensor = v->Get(); - shape = &tensor.Shape(); - data_type = tensor.DataType(); - } else if (type->IsSparseTensorType()) { -#if !defined(ORT_MINIMAL_BUILD) - const SparseTensor& tensor = v->Get(); - shape = &tensor.Shape(); - data_type = tensor.Values().DataType(); -#else - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "SparseTensor is not supported in this build."); -#endif + if (type->IsTensorType() || type->IsSparseTensorType()) { + const onnxruntime::TensorShape* shape = nullptr; + onnxruntime::MLDataType data_type = nullptr; + if (type->IsTensorType()) { + const Tensor& tensor = v->Get(); + shape = &tensor.Shape(); + data_type = tensor.DataType(); + } else { + const SparseTensor& tensor = v->Get(); + shape = &tensor.Shape(); + data_type = tensor.Values().DataType(); + } + return GetTensorShapeAndType(*shape, *data_type, out); } else { ORT_THROW("Argument is not a tensor"); } - - return GetTensorShapeAndType(*shape, *data_type, out); API_IMPL_END } diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 34a3c04d4b..9f5e7583ea 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -637,13 +637,11 @@ common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& n *tensor.mutable_string_data() = constant_attribute.strings(); break; } -#if !defined(ORT_MINIMAL_BUILD) case AttributeProto_AttributeType_SPARSE_TENSOR: { auto& s = constant_attribute.sparse_tensor(); ORT_RETURN_IF_ERROR(SparseTensorProtoToDenseTensorProto(s, tensor)); break; } -#endif default: ORT_THROW("Unsupported attribute value type of ", constant_attribute.type(), " in 'Constant' node '", node.name(), "'"); @@ -663,7 +661,19 @@ static Status CopySparseData(size_t n_sparse_elements, Status status = Status::OK(); TensorShape indices_shape(indices.dims().data(), indices.dims().size()); - auto indices_data = gsl::make_span(indices.int64_data().data(), static_cast(indices_shape.Size())); + ORT_RETURN_IF_NOT(indices.data_type() == ONNX_NAMESPACE ::TensorProto_DataType_INT64, "Indicies expected to be INT64"); + + gsl::span indices_data; + const auto elements = static_cast(indices_shape.Size()); + if (indices.int64_data_size() > 0) { + indices_data = gsl::make_span(indices.int64_data().data(), elements); + } else if (indices.has_raw_data()) { + ORT_RETURN_IF_NOT(indices.raw_data().size() == (elements * sizeof(int64_t)), + "Sparse Indicies raw data size does not match expected."); + indices_data = gsl::make_span(reinterpret_cast(indices.raw_data().data()), elements); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Invalid SparseTensor indices. Should either have raw or int64 data"); + } if (indices_shape.NumDimensions() == 1) { // flattened indexes @@ -707,7 +717,20 @@ static Status CopySparseData(size_t n_sparse_elements, return status; } -#if !defined(ORT_MINIMAL_BUILD) +struct UnsupportedSparseDataType { + Status operator()(int32_t dt_type) const { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported sparse tensor data type of ", dt_type); + } +}; + +template +struct GetElementSize { + Status operator()(size_t& element_size) const { + element_size = sizeof(T); + return Status::OK(); + } +}; + common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseTensorProto& sparse, ONNX_NAMESPACE::TensorProto& dense) { Status status = Status::OK(); @@ -715,6 +738,7 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT const auto& sparse_values = sparse.values(); auto type = sparse_values.data_type(); dense.set_data_type(type); + *dense.mutable_name() = sparse_values.name(); SafeInt n_sparse_elements = 1; for (auto dim : sparse_values.dims()) { @@ -730,60 +754,37 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT const auto& indices = sparse.indices(); auto dims = gsl::make_span(dense.dims().data(), dense.dims().size()); - // need to read in sparse data first as it could be in a type specific field, in raw data, or in external data - size_t sparse_bytes = 0; - ORT_RETURN_IF_ERROR(GetSizeInBytesFromTensorProto<0>(sparse_values, &sparse_bytes)); - if (type != TensorProto_DataType_STRING) { - std::vector sparse_data_storage(sparse_bytes, 0); - void* sparse_data = sparse_data_storage.data(); - + // need to read in sparse data first as it could be in a type specific field, in raw data, or in external data + size_t sparse_bytes = 0; + std::unique_ptr sparse_data_storage; + ORT_RETURN_IF_ERROR(UnpackInitializerData(sparse_values, sparse_data_storage, sparse_bytes)); + void* sparse_data = sparse_data_storage.get(); size_t element_size = 0; - - // setup buffer for output - switch (type) { - case TensorProto_DataType_FLOAT: { - element_size = sizeof(float); - UnpackTensor(sparse_values, static_cast(sparse_data), n_sparse_elements); - break; - } - case TensorProto_DataType_INT64: { - element_size = sizeof(int64_t); - UnpackTensor(sparse_values, static_cast(sparse_data), n_sparse_elements); - break; - } - case TensorProto_DataType_INT32: { - element_size = sizeof(int32_t); - UnpackTensor(sparse_values, static_cast(sparse_data), n_sparse_elements); - break; - } - case TensorProto_DataType_DOUBLE: { - element_size = sizeof(double); - UnpackTensor(sparse_values, static_cast(sparse_data), n_sparse_elements); - break; - } - case TensorProto_DataType_UINT32: { - element_size = sizeof(uint32_t); - UnpackTensor(sparse_values, static_cast(sparse_data), n_sparse_elements); - break; - } - case TensorProto_DataType_UINT64: { - element_size = sizeof(uint64_t); - UnpackTensor(sparse_values, static_cast(sparse_data), n_sparse_elements); - break; - } - default: - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported sparse tensor data type of ", type); - } + // We want to this list to match the one used below in DenseTensorToSparseTensorProto() + MLTypeCallDispatcherRet type_disp(type); + ORT_RETURN_IF_ERROR(type_disp.InvokeWithUnsupportedPolicy(element_size)); // by putting the data into a std::string we can avoid a copy as set_raw_data can do a std::move // into the TensorProto. however to actually write to the buffer we have created in the std::string we need // this somewhat dirty hack to get a mutable pointer. we could alternatively use &dense_data_storage.front() // but using const_cast makes it more obvious we're doing something ugly. + // C++17 add non-const data() where we could remove const_cast std::string dense_data_storage(n_dense_elements * element_size, 0); void* dense_data = const_cast(dense_data_storage.data()); switch (element_size) { + case 1: { + auto dense_data_span = gsl::make_span(static_cast(dense_data), n_dense_elements); + status = CopySparseData( + n_sparse_elements, + indices, dims, + [sparse_data, dense_data_span](size_t from_idx, size_t to_idx) { + dense_data_span[to_idx] = static_cast(sparse_data)[from_idx]; + }); + + break; + } case 4: { auto dense_data_span = gsl::make_span(static_cast(dense_data), n_dense_elements); status = CopySparseData( @@ -795,45 +796,127 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT break; } - case 8: { - auto dense_data_span = gsl::make_span(static_cast(dense_data), n_dense_elements); - status = CopySparseData( - n_sparse_elements, - indices, dims, - [sparse_data, dense_data_span](size_t from_idx, size_t to_idx) { - dense_data_span[to_idx] = static_cast(sparse_data)[from_idx]; - }); - - break; - } + default: + ORT_THROW(false, "BUG! Report to onnxruntime team."); } + ORT_RETURN_IF_ERROR(status); dense.set_raw_data(std::move(dense_data_storage)); } else { - // strings need to be handled differently as they can't use raw data (as per ONNX rules) - std::vector sparse_data(n_sparse_elements); - UnpackTensor(sparse_values, sparse_data.data(), n_sparse_elements); - - // RepeatedPtrField doesn't have a Resize method so manually add elements - auto dense_strings = dense.mutable_string_data(); - dense_strings->Reserve(n_dense_elements); - for (int64_t j = 0; j < n_dense_elements; ++j) { - dense_strings->Add(""); - } - - status = CopySparseData( - n_sparse_elements, - indices, dims, - [&sparse_values, &dense_strings](size_t from_idx, size_t to_idx) { - const std::string& input = sparse_values.string_data()[SafeInt(from_idx)]; - *dense_strings->Mutable(SafeInt(to_idx)) = input; - }); + // No request for std::string + status = UnsupportedSparseDataType()(ONNX_NAMESPACE::TensorProto_DataType_STRING); } - return status; } -#endif // !defined(ORT_MINIMAL_BUILD) + + +#if !defined (ORT_MINIMAL_BUILD) +// Determines if this is a type specific zero +using IsZeroFunc = bool (*)(const void*); +// Copy element +using CopyElementFunc = void (*)(void* dest, const void* src, int64_t dest_index, int64_t src_index); + +static void SparsifyGeneric(const void* dense_raw_data, size_t n_dense_elements, size_t element_size, + IsZeroFunc is_zero, CopyElementFunc copy, + TensorProto& values, TensorProto& indices) { + + auto advance = [element_size](const void* start, size_t elements) -> const void* { + return (reinterpret_cast(start) + elements * element_size); + }; + + const auto* cbegin = dense_raw_data; + const auto* const cend = advance(cbegin, n_dense_elements); + auto& indices_data = *indices.mutable_int64_data(); + int64_t index = 0; + while (cbegin != cend) { + if (!is_zero(cbegin)) { + indices_data.Add(index); + } + ++index; + cbegin = advance(cbegin, 1U); + } + + auto& raw_data = *values.mutable_raw_data(); + raw_data.resize(indices.int64_data_size() * element_size); + void* data_dest = const_cast(raw_data.data()); + + int64_t dest_index = 0; + for (auto src_index : indices.int64_data()) { + copy(data_dest, dense_raw_data, dest_index, src_index); + ++dest_index; + } +} + +template +bool IsZero(const void* p) { + return (static_cast(0) == *reinterpret_cast(p)); +} + +template +void CopyElement(void* dst, const void* src, int64_t dst_index, int64_t src_index) { + reinterpret_cast(dst)[dst_index] = reinterpret_cast(src)[src_index]; +} + +common::Status DenseTensorToSparseTensorProto(const ONNX_NAMESPACE::TensorProto& dense_proto, + ONNX_NAMESPACE::SparseTensorProto& result) { + ORT_ENFORCE(HasDataType(dense_proto), "Must have a valid data type"); + + const bool is_string_data = dense_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_STRING; + if (is_string_data) { + return UnsupportedSparseDataType()(ONNX_NAMESPACE::TensorProto_DataType_STRING); + } + + const auto data_type = dense_proto.data_type(); + SparseTensorProto sparse_proto; + auto& values = *sparse_proto.mutable_values(); + values.set_name(dense_proto.name()); + values.set_data_type(data_type); + + auto& indices = *sparse_proto.mutable_indices(); + indices.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + + SafeInt n_dense_elements = 1; + for (auto dim : dense_proto.dims()) { + n_dense_elements *= dim; + } + + size_t tensor_bytes_size = 0; + std::unique_ptr dense_raw_data; + ORT_RETURN_IF_ERROR(UnpackInitializerData(dense_proto, dense_raw_data, tensor_bytes_size)); + size_t element_size = 0; + MLTypeCallDispatcherRet type_disp(data_type); + ORT_RETURN_IF_ERROR(type_disp.InvokeWithUnsupportedPolicy(element_size)); + + switch (element_size) { + case 1: { + // bytes + SparsifyGeneric(dense_raw_data.get(), n_dense_elements, element_size, + IsZero, CopyElement, values, indices); + break; + } + case 4: { + // float + SparsifyGeneric(dense_raw_data.get(), n_dense_elements, element_size, + IsZero, CopyElement, values, indices); + break; + } + default: + ORT_THROW(false, "BUG! Report to onnxruntime team."); + } + + // Fix up shapes + const auto nnz = indices.int64_data_size(); + values.add_dims(nnz); + indices.add_dims(nnz); + + // Save dense shape + *sparse_proto.mutable_dims() = dense_proto.dims(); + swap(result, sparse_proto); + return Status::OK(); +} + +#endif // !ORT_MINIMAL_BUILD template common::Status GetSizeInBytesFromTensorProto<256>(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out); diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 75637cb213..7efda74075 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -66,11 +66,14 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node, ONNX_NAMESPACE::TensorProto& tensor); -#if !defined(ORT_MINIMAL_BUILD) // Convert a SparseTensorProto to a dense TensorProto common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseTensorProto& sparse, ONNX_NAMESPACE::TensorProto& dense); -#endif + +#if !defined(ORT_MINIMAL_BUILD) +common::Status DenseTensorToSparseTensorProto(const ONNX_NAMESPACE::TensorProto& dense, + ONNX_NAMESPACE::SparseTensorProto& sparse); +#endif // !ORT_MINIMAL_BUILD inline bool HasDimValue(const ONNX_NAMESPACE::TensorShapeProto_Dimension& dim) { return dim.value_case() == ONNX_NAMESPACE::TensorShapeProto_Dimension::kDimValue; @@ -122,6 +125,11 @@ inline bool HasElemType(const ONNX_NAMESPACE::TypeProto_SparseTensor& ten_proto) return ten_proto.elem_type() != ONNX_NAMESPACE::TensorProto::UNDEFINED; } +inline bool HasName(const ONNX_NAMESPACE::SparseTensorProto& ten_proto) { + return ten_proto.values().has_name(); // XXX +} + + inline bool HasKeyType(const ONNX_NAMESPACE::TypeProto_Map& map_proto) { return map_proto.key_type() != ONNX_NAMESPACE::TensorProto::UNDEFINED; } diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index a1c7341113..ec6d117bfc 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -186,14 +186,12 @@ const TensorShapeProto* NodeArg::Shape() const { } return nullptr; } -#if !defined(ORT_MINIMAL_BUILD) case TypeProto::kSparseTensorType: { if (utils::HasShape(type->sparse_tensor_type())) { return &(type->sparse_tensor_type().shape()); } return nullptr; } -#endif case TypeProto::kSequenceType: case TypeProto::kMapType: case TypeProto::kOpaqueType: @@ -793,15 +791,13 @@ ADD_BASIC_ATTR_IMPL(float, AttributeProto_AttributeType::AttributeProto_Attribut ADD_BASIC_ATTR_IMPL(int64_t, AttributeProto_AttributeType::AttributeProto_AttributeType_INT, i) ADD_BASIC_ATTR_IMPL(std::string, AttributeProto_AttributeType::AttributeProto_AttributeType_STRING, s) ADD_ATTR_IMPL(TensorProto, AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR, t) +ADD_ATTR_IMPL(SparseTensorProto, AttributeProto_AttributeType::AttributeProto_AttributeType_SPARSE_TENSOR, sparse_tensor) ADD_LIST_ATTR_IMPL(float, AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS, floats) ADD_LIST_ATTR_IMPL(int64_t, AttributeProto_AttributeType::AttributeProto_AttributeType_INTS, ints) ADD_LIST_ATTR_IMPL(std::string, AttributeProto_AttributeType::AttributeProto_AttributeType_STRINGS, strings) ADD_LIST_ATTR_IMPL(TensorProto, AttributeProto_AttributeType::AttributeProto_AttributeType_TENSORS, tensors) ADD_LIST_ATTR_IMPL(GraphProto, AttributeProto_AttributeType::AttributeProto_AttributeType_GRAPHS, graphs) -#if !defined(ORT_MINIMAL_BUILD) -ADD_ATTR_IMPL(SparseTensorProto, AttributeProto_AttributeType::AttributeProto_AttributeType_SPARSE_TENSOR, sparse_tensor) ADD_LIST_ATTR_IMPL(SparseTensorProto, AttributeProto_AttributeType::AttributeProto_AttributeType_SPARSE_TENSORS, sparse_tensors) -#endif #if !defined(ORT_MINIMAL_BUILD) bool Node::ClearAttribute(const std::string& attr_name) { @@ -978,6 +974,10 @@ Graph::Graph(const Model& owning_model, const gsl::not_null tensor{graph_proto_->add_initializer()}; auto status = utils::ConstantNodeProtoToTensorProto(node, *tensor); ORT_ENFORCE(status.IsOK(), status.ToString()); + if (node.attribute(0).type() == AttributeProto_AttributeType_SPARSE_TENSOR) { + auto p = sparse_tensor_names_.emplace(tensor->name()); + ORT_ENFORCE(p.second, "Duplicate constant node sparse initializer name: '", tensor->name(), "' Model is invalid."); + } } // Remove constant nodes as they're replaced with initializers above. @@ -989,6 +989,28 @@ Graph::Graph(const Model& owning_model, }), graph_mutable_nodes->end()); + // For now we convert sparse_intializer to dense tensors + // since there are currently no supported ops that consume sparse + // initializers directly. We remove them from graph_proto. We will reconstitute them + // when saving to ORT format to save space on disk. + if (graph_proto_->sparse_initializer_size() > 0) { + for (const auto& sparse_tensor : graph_proto_->sparse_initializer()) { + ORT_ENFORCE(utils::HasName(sparse_tensor), "Sparse initializer must have a name. This model is invalid"); + const gsl::not_null tensor{graph_proto_->add_initializer()}; + auto status = utils::SparseTensorProtoToDenseTensorProto(sparse_tensor, *tensor); + ORT_ENFORCE(status.IsOK(), status.ToString()); + auto p = sparse_tensor_names_.emplace(tensor->name()); + ORT_ENFORCE(p.second, "Duplicate sparse_tensor_initializer: '", tensor->name(), "' Model is invalid."); + } + + // Remove sparse_initializers from protobuf to save memory as they are converted to dense now + graph_proto_->mutable_sparse_initializer()->Clear(); + const int sparse_num_cleared = graph_proto_->sparse_initializer().ClearedCount(); + for (int i = 0; i < sparse_num_cleared; ++i) { + delete graph_proto_->mutable_sparse_initializer()->ReleaseCleared(); + } + } + // Collect all node arg name, type, shape information in the graph. // type/shape information will be assigned to each node arg when going // thru all nodes later. @@ -1012,7 +1034,13 @@ Graph::Graph(const Model& owning_model, // Copy initial tensors to a map. for (auto& tensor : graph_proto_->initializer()) { - name_to_initial_tensor_[tensor.name()] = &tensor; + auto p = name_to_initial_tensor_.emplace(tensor.name(), &tensor); + if (!p.second) { + LOGS(logger_, WARNING) << "Duplicate initializer (dense, sparse or ConstantNode): '" << tensor.name() + << "' the model will use the latest encountered initializer" + << ". Please, fix your model."; + p.first->second = &tensor; + } NodeArg* matching_graph_input = GetNodeArg(tensor.name()); TypeProto t{TypeProtoFromTensorProto(tensor)}; @@ -1602,7 +1630,7 @@ void Graph::KahnsTopologicalSort(const std::function& enter, topo_order.push_back(current->Index()); } - if (NumberOfNodes() != static_cast(topo_order.size())) { + if (NumberOfNodes() != static_cast(topo_order.size())) { ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle."); } } @@ -1703,12 +1731,10 @@ bool FullyDefinedType(const TypeProto& type_proto) { auto& tensor_type = type_proto.tensor_type(); return utils::HasElemType(tensor_type); } -#if !defined(ORT_MINIMAL_BUILD) case TypeProto::kSparseTensorType: { auto& tensor_type = type_proto.sparse_tensor_type(); return utils::HasElemType(tensor_type); } -#endif case TypeProto::kSequenceType: { auto& seq_type = type_proto.sequence_type(); return utils::HasElemType(seq_type) && FullyDefinedType(seq_type.elem_type()); @@ -2565,8 +2591,11 @@ void Graph::RemoveInitializedTensor(const std::string& tensor_name) { auto iter = name_to_initial_tensor_.find(tensor_name); found = iter != name_to_initial_tensor_.end(); if (found) { - name_to_initial_tensor_.erase(tensor_name); + name_to_initial_tensor_.erase(iter); + sparse_tensor_names_.erase(tensor_name); SetGraphResolveNeeded(); + } else { + ORT_ENFORCE(sparse_tensor_names_.count(tensor_name) == 0, "sparse_tensor_names_ not in sync with name_to_initial_tensor_"); } auto& mutable_initializers = *(graph_proto_->mutable_initializer()); @@ -2632,6 +2661,7 @@ bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorPro void Graph::CleanAllInitializedTensors() noexcept { name_to_initial_tensor_.clear(); + sparse_tensor_names_.clear(); // Clearing RepeatedPtrFields does not free objects' memory. The memory is retained // and can be reused. Need to explicitly release the cleared objects and free the @@ -2766,15 +2796,30 @@ common::Status Graph::SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, auto inputs = SaveInputsOutputsToOrtFormat(builder, graph_inputs_including_initializers_); auto outputs = SaveInputsOutputsToOrtFormat(builder, graph_outputs_); + std::vector> sparse_initializers_data; + sparse_initializers_data.reserve(sparse_tensor_names_.size()); + const auto sparse_end = sparse_tensor_names_.end(); + std::vector> initializers_data; - initializers_data.reserve(name_to_initial_tensor_.size()); + assert(sparse_tensor_names_.size() <= name_to_initial_tensor_.size()); + initializers_data.reserve(name_to_initial_tensor_.size() - sparse_tensor_names_.size()); for (const auto& pair : name_to_initial_tensor_) { - flatbuffers::Offset fbs_tensor; - ORT_RETURN_IF_ERROR( - experimental::utils::SaveInitializerOrtFormat(builder, *pair.second, fbs_tensor)); - initializers_data.push_back(fbs_tensor); + if (sparse_tensor_names_.find(pair.first) == sparse_end) { + flatbuffers::Offset fbs_tensor; + ORT_RETURN_IF_ERROR( + experimental::utils::SaveInitializerOrtFormat(builder, *pair.second, fbs_tensor)); + initializers_data.push_back(fbs_tensor); + } else { + SparseTensorProto sparse_initializer; + ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(*pair.second, sparse_initializer)); + flatbuffers::Offset fbs_sparse_tensor; + ORT_RETURN_IF_ERROR( + experimental::utils::SaveSparseInitializerOrtFormat(builder, sparse_initializer, fbs_sparse_tensor)); + sparse_initializers_data.push_back(fbs_sparse_tensor); + } } auto initializers = builder.CreateVector(initializers_data); + auto sparse_initializers = builder.CreateVector(sparse_initializers_data); std::vector> node_args_data; node_args_data.reserve(node_args_.size()); @@ -2808,6 +2853,7 @@ common::Status Graph::SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, gb.add_node_edges(node_edges); gb.add_inputs(inputs); gb.add_outputs(outputs); + gb.add_sparse_initializers(sparse_initializers); fbs_graph = gb.Finish(); return Status::OK(); } @@ -2932,13 +2978,30 @@ const ONNX_NAMESPACE::GraphProto& Graph::ToGraphProto() { } ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const { - if (!GraphProtoSyncNeeded()) { + if (!GraphProtoSyncNeeded() && sparse_tensor_names_.empty()) { return *graph_proto_; } + GraphProto result; ToGraphProtoInternal(result); - *result.mutable_initializer() = graph_proto_->initializer(); + // We want to make sure that sparse initializers do not appear + // as dense duplicates within the initializers list. + if (!sparse_tensor_names_.empty()) { + const auto sparse_end = sparse_tensor_names_.end(); + auto* mutable_initializer = result.mutable_initializer(); + for (const auto& initializer : graph_proto_->initializer()) { + if (sparse_end == sparse_tensor_names_.find(initializer.name())) { + *mutable_initializer->Add() = initializer; + } else { + auto& sparse_initializer = *result.add_sparse_initializer(); + auto status = utils::DenseTensorToSparseTensorProto(initializer, sparse_initializer); + ORT_ENFORCE(status.IsOK(), "Failed to convert dense initializer to sparse"); + } + } + } else { + *result.mutable_initializer() = graph_proto_->initializer(); + } return result; } @@ -3450,6 +3513,7 @@ Status Graph::LoadFromOrtFormat( // which will allow optimizers to run or non-ORT EPs to take nodes. // TODO: We could decide that an ORT model is load only even in a full build, // and in InferenceSession::Initialize skip partitioning and running optimizers. + graph->SetGraphResolveNeeded(); ORT_RETURN_IF_ERROR(graph->Resolve()); #else // probably nothing required here. validate with model that has nested subgraphs. @@ -3487,7 +3551,7 @@ Graph::Graph(const Model& owning_model, common::Status Graph::LoadFromOrtFormat(const onnxruntime::experimental::fbs::Graph& fbs_graph) { // We deserialize the graph from ORT format in the following order: - // 1. Deserialize the initializers + // 1. Deserialize the initializers and sparse initializers. Convert sparse to dense. // 2. Deserialize the NodeArgs // We need all NodeArg instances to exist when deserializing Nodes to setup the Node's // inputs/outputs/implicit inputs which are collections of NodeArg*. @@ -3498,13 +3562,45 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::experimental::fbs::Gr // Initializers auto fbs_initializers = fbs_graph.initializers(); + auto fbs_sparse_initializers = fbs_graph.sparse_initializers(); + flatbuffers::uoffset_t map_size = (fbs_initializers != nullptr ? fbs_initializers->size() : 0U) + + (fbs_sparse_initializers != nullptr ? fbs_sparse_initializers->size() : 0U); + + if (map_size > 0) { + name_to_initial_tensor_.reserve(map_size); + } + if (fbs_initializers) { - name_to_initial_tensor_.reserve(fbs_initializers->size()); for (const auto* fbs_tensor : *fbs_initializers) { ORT_RETURN_IF(nullptr == fbs_tensor, "Initializer tensor is missing. Invalid ORT format model."); TensorProto* initializer = deserialized_proto_data_.add_initializer(); ORT_RETURN_IF_ERROR(experimental::utils::LoadInitializerOrtFormat(*fbs_tensor, *initializer)); - name_to_initial_tensor_[initializer->name()] = initializer; + auto p = name_to_initial_tensor_.emplace(initializer->name(), initializer); + if (!p.second) { + LOGS(logger_, WARNING) << "Duplicate initializer (dense or ConstantNode): '" << initializer->name() + << "' the model will use the latest encountered initializer" + << ". Please, fix your model."; + p.first->second = initializer; + } + } + } + + if (fbs_sparse_initializers) { + sparse_tensor_names_.reserve(fbs_sparse_initializers->size()); + for (const auto* fbs_sparse_tensor : *fbs_sparse_initializers) { + ORT_RETURN_IF(nullptr == fbs_sparse_tensor, "Sparse Initializer tensor is missing. Invalid ORT format model."); + SparseTensorProto sparse_initializer; + ORT_RETURN_IF_ERROR(experimental::utils::LoadSparseInitializerOrtFormat(*fbs_sparse_tensor, sparse_initializer)); + TensorProto& initializer = *deserialized_proto_data_.add_initializer(); + ORT_RETURN_IF_ERROR(utils::SparseTensorProtoToDenseTensorProto(sparse_initializer, initializer)); + auto p = name_to_initial_tensor_.emplace(initializer.name(), &initializer); + if (!p.second) { + LOGS(logger_, WARNING) << "Duplicate initializer (dense, sparse or ConstantNode): '" << initializer.name() + << "' the model will use the latest encountered initializer" + << ". Please, fix your model."; + p.first->second = &initializer; + } + sparse_tensor_names_.emplace(initializer.name()); } } diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc index d69c286190..7f2087383d 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc @@ -18,19 +18,26 @@ namespace utils { #if !defined(ORT_MINIMAL_BUILD) +template +inline flatbuffers::Offset> +SaveDims(flatbuffers::FlatBufferBuilder& builder, const DimsFieldType& dims) { + std::vector dims_data(dims.size()); + std::copy(dims.cbegin(), dims.cend(), dims_data.begin()); + return builder.CreateVector(dims_data); +} + Status SaveInitializerOrtFormat(flatbuffers::FlatBufferBuilder& builder, const TensorProto& initializer, flatbuffers::Offset& fbs_tensor) { auto name = SaveStringToOrtFormat(builder, initializer.has_name(), initializer.name()); auto doc_string = SaveStringToOrtFormat(builder, initializer.has_doc_string(), initializer.doc_string()); - std::vector dims_data(initializer.dims().size()); - std::copy(initializer.dims().cbegin(), initializer.dims().cend(), dims_data.begin()); - auto dims = builder.CreateVector(dims_data); + auto dims = SaveDims(builder, initializer.dims()); + flatbuffers::Offset>> string_data; flatbuffers::Offset> raw_data; auto src_type = initializer.data_type(); - bool has_string_data = src_type == ONNX_NAMESPACE::TensorProto_DataType_STRING; + const bool has_string_data = src_type == ONNX_NAMESPACE::TensorProto_DataType_STRING; if (has_string_data) { std::vector string_data_vec(initializer.string_data().size()); std::copy(initializer.string_data().cbegin(), initializer.string_data().cend(), string_data_vec.begin()); @@ -56,6 +63,32 @@ Status SaveInitializerOrtFormat(flatbuffers::FlatBufferBuilder& builder, return Status::OK(); } +Status SaveSparseInitializerOrtFormat(flatbuffers::FlatBufferBuilder& builder, + const ONNX_NAMESPACE::SparseTensorProto& initializer, + flatbuffers::Offset& fbs_sparse_tensor) { + // values + const auto& values = initializer.values(); + flatbuffers::Offset values_off; + ORT_RETURN_IF_ERROR(SaveInitializerOrtFormat(builder, values, values_off)); + + // Indicies + const auto& indicies = initializer.indices(); + flatbuffers::Offset indicies_off; + ORT_RETURN_IF_ERROR(SaveInitializerOrtFormat(builder, indicies, indicies_off)); + + // Shape + auto shape = SaveDims(builder, initializer.dims()); + + fbs::SparseTensorBuilder stb(builder); + stb.add_values(values_off); + stb.add_indices(indicies_off); + stb.add_dims(shape); + + fbs_sparse_tensor = stb.Finish(); + + return Status::OK(); +} + #define GET_FBS_ATTR(BUILDER, TYPE, DATA_NAME, DATA) \ fbs::AttributeBuilder attr_builder(BUILDER); \ attr_builder.add_name(name); \ @@ -163,7 +196,7 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, mutable_str_data->Add(fbs_str->str()); } } else { - auto fbs_raw_data = fbs_tensor.raw_data(); + const auto* fbs_raw_data = fbs_tensor.raw_data(); ORT_RETURN_IF(nullptr == fbs_raw_data, "Missing raw data for initializer. Invalid ORT format model."); // fbs_raw_data is uint8_t vector, so the size is byte size @@ -173,6 +206,30 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, return Status::OK(); } +Status LoadSparseInitializerOrtFormat(const fbs::SparseTensor& fbs_sparse_tensor, + SparseTensorProto& initializer) { + SparseTensorProto loaded_initializer; + auto fbs_values_tensor = fbs_sparse_tensor.values(); + ORT_RETURN_IF(nullptr == fbs_values_tensor, "Missing values for sparse initializer. Invalid ORT format model."); + auto* values_tensor = loaded_initializer.mutable_values(); + ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_values_tensor, *values_tensor)); + ORT_RETURN_IF(values_tensor->name().empty(), "Missing name for SparseTensor initializer. Invalid ORT format model."); + + auto fbs_indicies_tensor = fbs_sparse_tensor.indices(); + ORT_RETURN_IF(nullptr == fbs_indicies_tensor, "Missing indicies for sparse initializer: ", "'", values_tensor->name(), "'", + "Invalid ORT format model."); + auto* indicies_tensor = loaded_initializer.mutable_indices(); + ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_indicies_tensor, *indicies_tensor)); + + auto fbs_dims = fbs_sparse_tensor.dims(); + ORT_RETURN_IF(nullptr == fbs_dims, "Missing dims for sparse initializer: ", "'", values_tensor->name(), "'", + "Invalid ORT format model."); + loaded_initializer.mutable_dims()->Add(fbs_dims->cbegin(), fbs_dims->cend()); + + swap(loaded_initializer, initializer); + return Status::OK(); +} + Status LoadAttributeOrtFormat(const fbs::Attribute& fbs_attr, ONNX_NAMESPACE::AttributeProto& attr_proto, std::unique_ptr& sub_graph, diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.h b/onnxruntime/core/graph/graph_flatbuffers_utils.h index 11d81bc134..fe3ae9334b 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.h +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.h @@ -5,6 +5,7 @@ namespace ONNX_NAMESPACE { class TensorProto; +class SparseTensorProto; class AttributeProto; } // namespace ONNX_NAMESPACE @@ -37,6 +38,10 @@ onnxruntime::common::Status SaveInitializerOrtFormat( flatbuffers::FlatBufferBuilder& builder, const ONNX_NAMESPACE::TensorProto& initializer, flatbuffers::Offset& fbs_tensor); +onnxruntime::common::Status SaveSparseInitializerOrtFormat( + flatbuffers::FlatBufferBuilder& builder, const ONNX_NAMESPACE::SparseTensorProto& initializer, + flatbuffers::Offset& fbs_sparse_tensor); + // Convert a given AttributeProto into fbs::Attribute // Note, we current do not support graphs, and sparse_tensor(s) // If the attribute type is a graph, we need to use the supplied graph, @@ -50,6 +55,9 @@ onnxruntime::common::Status SaveAttributeOrtFormat( onnxruntime::common::Status LoadInitializerOrtFormat( const fbs::Tensor& fbs_tensor, ONNX_NAMESPACE::TensorProto& initializer); +onnxruntime::common::Status LoadSparseInitializerOrtFormat(const fbs::SparseTensor& fbs_sparse_tensor, + ONNX_NAMESPACE::SparseTensorProto& initializer); + // Load a give fbs::Attribute into AttributeProto // Note, If the attribute type is a graph, we will leave an empty graph in attr_proto, // and set the deserialized Graph to the param graph diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 5a6dc164a3..9fcd078fd8 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -240,8 +240,15 @@ const Graph& Model::MainGraph() const noexcept { #if !defined(ORT_MINIMAL_BUILD) ModelProto Model::ToProto() { - *(model_proto_.mutable_graph()) = graph_->ToGraphProto(); - return model_proto_; + // We want to return back the original proto + // To that end invoke const overload of ToGraphProto() + // that returns by value and, therefore, allows us to filter + // out dense duplicates of sparse initializers and leave the original + // proto intact. + ModelProto result(model_proto_); + const auto& graph = *graph_; + *(result.mutable_graph()) = graph.ToGraphProto(); + return result; } Status Model::Load(std::istream& model_istream, ModelProto* p_model_proto) { diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 53d42475c1..53968de60e 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -244,7 +244,7 @@ class Model { // properties that would normally come from ModelProto std::string producer_version_; std::string producer_name_; - int64_t model_version_ = 0; + int64_t model_version_ = kNoVersion; int64_t ir_version_ = kNoVersion; std::string domain_; std::string doc_string_; diff --git a/onnxruntime/core/providers/cpu/element_wise_ranged_transform.h b/onnxruntime/core/providers/cpu/element_wise_ranged_transform.h index 0a219cbbc4..b7fe209203 100644 --- a/onnxruntime/core/providers/cpu/element_wise_ranged_transform.h +++ b/onnxruntime/core/providers/cpu/element_wise_ranged_transform.h @@ -109,4 +109,4 @@ class ElementWiseKernel final : public OpKernel { #define DEFINE_ELE_KERNEL(X) \ template \ using X = ElementWiseKernel>; -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ErrorHandling.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ErrorHandling.h index c7fda113d2..bf8b428c56 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ErrorHandling.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ErrorHandling.h @@ -8,6 +8,6 @@ auto _status = status; \ if (!_status.IsOK()) \ { \ - THROW_HR(StatusCodeToHRESULT(static_cast(_status.Code()))); \ + THROW_HR(StatusCodeToHRESULT(static_cast(_status.Code()))); \ } \ } while (0) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 40c7f47ee9..3777a97daa 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -102,15 +102,19 @@ std::atomic InferenceSession::global_session_id_{1}; // Only update this version when there is a file format change which will break the compatibilites // Once this model version is updated, the kSupportedOrtModelVersions in IsOrtModelVersionSupported // below will also need to be updated. -static constexpr const char* kOrtModelVersion = "1"; +// See onnxruntime/core/session/flatbuffers/schema/README.md for more details on versioning. +// Version 1 - history begins +// Version 2 - add serailization/deserialization of sparse_initializer +static constexpr const char* kOrtModelVersion = "2"; #if defined(ENABLE_ORT_FORMAT_LOAD) -// Check if the givne ort model version is supported in this build +// Check if the given ort model version is supported in this build static bool IsOrtModelVersionSupported(const std::string& ort_model_version) { // The ort model versions we will support in this build // This may contain more versions than the kOrtModelVersion, based on the compatibilities static const std::unordered_set kSupportedOrtModelVersions{ std::string("1.4.0"), // This is a special model version for existing converted model + std::string("1"), std::string(kOrtModelVersion), }; @@ -1340,20 +1344,20 @@ common::Status InferenceSession::ValidateInputs(const std::vector& ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(feed_name, input_shape, expected_shape)); } } else if (input_ml_value.IsSparseTensor()) { -#if !defined(ORT_MINIMAL_BUILD) if (!expected_type->IsSparseTensorType()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name, " is not expected to be of type sparse tensor."); } auto expected_element_type = expected_type->AsSparseTensorType()->GetElementType(); - auto input_element_type = input_ml_value.Get().Values().DataType(); - // TODO: In the future, when sparsetensors are in use, find out how to properly verify the shape + const SparseTensor& sparse_tensor = input_ml_value.Get(); + auto input_element_type = sparse_tensor.Values().DataType(); ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type)); -#else - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name ", feed_name, - " is a sparse tensor, which is not supported in this build."); -#endif - + // Check shape + const auto& expected_shape = iter->second.tensor_shape; + if (expected_shape.NumDimensions() > 0) { + const auto& input_shape = sparse_tensor.Shape(); + ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(feed_name, input_shape, expected_shape)); + } } else if (input_ml_value.IsTensorSequence()) { if (!expected_type->IsTensorSequenceType()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name, diff --git a/onnxruntime/test/framework/ort_model_only_test.cc b/onnxruntime/test/framework/ort_model_only_test.cc index 3d617a5f11..02c68d9c14 100644 --- a/onnxruntime/test/framework/ort_model_only_test.cc +++ b/onnxruntime/test/framework/ort_model_only_test.cc @@ -237,6 +237,7 @@ static void DumpOrtModelAsJson(const std::string& model_uri) { } */ + TEST(OrtModelOnlyTests, SerializeToOrtFormat) { const std::basic_string ort_file = ORT_TSTR("ort_github_issue_4031.onnx.ort"); SaveAndCompareModels("testdata/ort_github_issue_4031.onnx", ort_file); @@ -264,6 +265,26 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormat) { RunOrtModel(test_info); } +TEST(OrtModelOnlyTests, SparseInitializerHandling) { + const std::basic_string ort_file = ORT_TSTR("sparse_initializer_handling.onnx.ort"); + SaveAndCompareModels("testdata/sparse_initializer_handling.onnx", ort_file); + + SessionOptions so; + so.session_logid = "LoadOrtFormat"; + // not strictly necessary - type should be inferred from the filename, but to be sure we're testing what we + // think we're testing set it. + so.AddConfigEntry(kOrtSessionOptionsConfigLoadModelFormat, "ORT"); + InferenceSessionWrapper session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(ort_file)); + ASSERT_STATUS_OK(session_object.Initialize()); + + // Check that there are no duplicates for initializers + const auto* init_list = session_object.GetOverridableInitializers().second; + ASSERT_EQ(init_list->size(), 1U); + const auto& init_def = *init_list->front(); + ASSERT_EQ(init_def.Name(), "x"); +} + #if !defined(DISABLE_ML_OPS) TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) { const std::basic_string ort_file = ORT_TSTR("sklearn_bin_voting_classifier_soft_converted.ort"); @@ -308,6 +329,17 @@ TEST(OrtModelOnlyTests, SerializeToOrtFormatMLOps) { #endif // #if !defined(DISABLE_ML_OPS) #endif // #if !defined(ORT_MINIMAL_BUILD) +// test loading ORT format model with sparse initializers +TEST(OrtModelOnlyTests, LoadSparseInitializersOrtFormat) { + const std::basic_string ort_file = ORT_TSTR("testdata/sparse_initializer_handling.onnx.ort"); + SessionOptions so; + so.session_logid = "LoadOrtFormat"; + so.AddConfigEntry(kOrtSessionOptionsConfigLoadModelFormat, "ORT"); + InferenceSessionWrapper session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(ort_file)); + ASSERT_STATUS_OK(session_object.Initialize()); +} + OrtModelTestInfo GetTestInfoForLoadOrtFormatModel() { OrtModelTestInfo test_info; test_info.model_filename = ORT_TSTR("testdata/ort_github_issue_4031.onnx.ort"); diff --git a/onnxruntime/test/framework/sparse_kernels_test.cc b/onnxruntime/test/framework/sparse_kernels_test.cc index 5a6ba846ab..3e72ba1a5d 100644 --- a/onnxruntime/test/framework/sparse_kernels_test.cc +++ b/onnxruntime/test/framework/sparse_kernels_test.cc @@ -310,7 +310,7 @@ class SparseTensorTests : public testing::Test { registerop(registry.get()); } - void BuildModel() { + void BuildModel() { IOnnxRuntimeOpSchemaRegistryList custom_schema_registries = {registry->GetOpschemaRegistry()}; model.reset(new Model("SparseTensorTest", false, ModelMetaData(), PathString(), custom_schema_registries, {}, {}, DefaultLoggingManager().DefaultLogger())); @@ -503,10 +503,19 @@ static std::vector CreateValues() { return {1, 2, 3, 4}; } +/* std::string suport in the future template <> std::vector CreateValues() { return {"one", "two", "three", "four"}; } +*/ + +/* BFloat16 support in the future +template <> +std::vector CreateValues() { + return {BFloat16(1.f), BFloat16(2.f), BFloat16(3.f), BFloat16(4.f)}; +} +*/ template static NodeProto CreateConstantNode(bool indices_1D, @@ -545,6 +554,7 @@ static NodeProto CreateConstantNode(bool indices_1D, 1, 2, 0}; } + indices_tp.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); indices_tp.mutable_int64_data()->Add(indices.cbegin(), indices.cend()); expected_data.resize(2 * 3 * 2); @@ -600,6 +610,22 @@ static void RawDataChecker(gsl::span expected, const TensorProto& actua EXPECT_THAT(actual_span, testing::ContainerEq(expected)); } +/* For BFloat16 support in the future. +template <> +void RawDataChecker(gsl::span expected_bfloat, const TensorProto& actual) { + int64_t actual_size = 1; + for (const auto dim : actual.dims()) { + actual_size *= dim; + } + + auto expected = expected_bfloat.as_span(); + const uint16_t* raw_data = reinterpret_cast(actual.raw_data().data()); + auto actual_span = gsl::make_span(raw_data, actual_size); + + EXPECT_THAT(actual_span, testing::ContainerEq(expected)); +} +*/ + TEST(SparseTensorConversionTests, TestConstantNodeConversion) { TestConversion( [](const std::vector& values, TensorProto& tp) { @@ -608,69 +634,145 @@ TEST(SparseTensorConversionTests, TestConstantNodeConversion) { }, RawDataChecker); - TestConversion( - [](const std::vector& values, TensorProto& tp) { - tp.set_data_type(TensorProto_DataType_INT32); + TestConversion( + [](const std::vector& values, TensorProto& tp) { + tp.set_data_type(TensorProto_DataType_INT8); tp.mutable_int32_data()->Add(values.cbegin(), values.cend()); }, - RawDataChecker); + RawDataChecker); - TestConversion( - [](const std::vector& values, TensorProto& tp) { - tp.set_data_type(TensorProto_DataType_INT64); - tp.mutable_int64_data()->Add(values.cbegin(), values.cend()); + TestConversion( + [](const std::vector& values, TensorProto& tp) { + RawDataWriter(values, tp, TensorProto_DataType_UINT8); }, - RawDataChecker); - - TestConversion( - [](const std::vector& values, TensorProto& tp) { - tp.set_data_type(TensorProto_DataType_DOUBLE); - tp.mutable_double_data()->Add(values.cbegin(), values.cend()); - }, - RawDataChecker); - - TestConversion( - [](const std::vector& values, TensorProto& tp) { - tp.set_data_type(TensorProto_DataType_UINT32); - tp.mutable_uint64_data()->Add(values.cbegin(), values.cend()); // stored in uint64_data despite being uint32_t - }, - RawDataChecker); - - TestConversion( - [](const std::vector& values, TensorProto& tp) { - tp.set_data_type(TensorProto_DataType_UINT64); - tp.mutable_uint64_data()->Add(values.cbegin(), values.cend()); - }, - RawDataChecker); - - // test a couple of types with values in raw data field - TestConversion( - [](const std::vector& values, TensorProto& tp) { - RawDataWriter(values, tp, TensorProto_DataType_FLOAT); - }, - RawDataChecker); - - TestConversion( - [](const std::vector& values, TensorProto& tp) { - RawDataWriter(values, tp, TensorProto_DataType_INT64); - }, - RawDataChecker); - - // strings can't use raw data, and string_data is a RepeatedPtrField (vs. RepeatedField for simple types) - // so has to be handled differently - TestConversion( - [](const std::vector& values, TensorProto& tp) { - tp.set_data_type(TensorProto_DataType_STRING); - for (auto cur = values.cbegin(), end = values.cend(); cur < end; ++cur) { - tp.mutable_string_data()->Add(std::string(*cur)); - } - }, - [](gsl::span expected, const TensorProto& actual) { - const auto& actual_strings = actual.string_data(); - for (int64_t i = 0, end = expected.size(); i < end; ++i) { - EXPECT_EQ(actual_strings[static_cast(i)], expected[i]); - } - }); + RawDataChecker); } + +/// Dense to Sparse conversion tests +#if !defined(ORT_MINIMAL_BUILD) + +template +static std::vector CreateSparseValues() { + return {0, 2, 3, 0}; +} + +/* std::string support in the future +template <> +std::vector CreateSparseValues() { + return {"", "two", "three", ""}; +} +*/ + +/* BFloat16 support in the future +template <> +std::vector CreateSparseValues() { + return {BFloat16(0.f), BFloat16(2.f), BFloat16(3.f), BFloat16(0.f)}; +} +*/ + +template +TensorProto CreateDenseTensor(std::function& values, TensorProto& tp)> inserter, + std::vector& expected_values, std::vector& expected_indicies) { + TensorProto result; + std::vector values = CreateSparseValues(); + expected_indicies = {1, 2}; + for (const auto& ind : expected_indicies) { + expected_values.push_back(values[ind]); + } + inserter(values, result); + result.add_dims(static_cast(values.size())); + return result; +} + +template +static void RawSparseDataChecker(gsl::span expected_values, + gsl::span expected_indicies, + const SparseTensorProto& actual) { + int64_t actual_size = 1; + for (const auto dim : actual.values().dims()) { + actual_size *= dim; + } + + const T* raw_data = reinterpret_cast(actual.values().raw_data().data()); + auto actual_span = gsl::make_span(raw_data, actual_size); + + EXPECT_THAT(actual_span, testing::ContainerEq(expected_values)); + + // Check indicies + EXPECT_THAT(actual.indices().data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); + auto actual_indicies = gsl::make_span(actual.indices().int64_data().data(), actual.indices().int64_data_size()); + EXPECT_THAT(actual_indicies, testing::ContainerEq(expected_indicies)); +} + +/* When we support BFloat16 +template <> +void RawSparseDataChecker(gsl::span expected_bfloat, + gsl::span expected_indicies, + const SparseTensorProto& actual) { + int64_t actual_size = 1; + for (const auto dim : actual.values().dims()) { + actual_size *= dim; + } + + static_assert(sizeof(uint16_t) == sizeof(BFloat16), "Expecting equal sizes"); + auto expected = expected_bfloat.as_span(); + const uint16_t* raw_data = reinterpret_cast(actual.values().raw_data().data()); + auto actual_span = gsl::make_span(raw_data, actual_size); + + EXPECT_THAT(actual_span, testing::ContainerEq(expected)); + // Check indicies + EXPECT_THAT(actual.indices().data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); + auto actual_indicies = gsl::make_span(actual.indices().int64_data().data(), actual.indices().int64_data_size()); + EXPECT_THAT(actual_indicies, testing::ContainerEq(expected_indicies)); +} +*/ + +template +static void TestDenseToSparseConversion( + std::function& values, TensorProto& tp)> inserter, + std::function expected, + gsl::span expected_indicies, + const SparseTensorProto& actual)> + checker) { + std::vector expected_values; + std::vector expected_indicies; + TensorProto dense_tensor = CreateDenseTensor(inserter, expected_values, expected_indicies); + + SparseTensorProto sparse_tensor; + utils::DenseTensorToSparseTensorProto(dense_tensor, sparse_tensor); + + gsl::span + expected_values_span = gsl::make_span(expected_values.data(), expected_values.size()); + gsl::span expected_ind_span = gsl::make_span(expected_indicies.data(), expected_indicies.size()); + checker(expected_values_span, expected_ind_span, sparse_tensor); +} + +TEST(SparseTensorConversionTests, TestDenseToSparseConversion) { + TestDenseToSparseConversion( + [](const std::vector& values, TensorProto& tp) { + tp.set_data_type(TensorProto_DataType_FLOAT); + tp.set_name("dense_float"); + tp.mutable_float_data()->Add(values.cbegin(), values.cend()); + }, + RawSparseDataChecker); + + TestDenseToSparseConversion( + [](const std::vector& values, TensorProto& tp) { + tp.set_name("dense_int8"); + tp.set_data_type(TensorProto_DataType_INT8); + tp.mutable_int32_data()->Add(values.cbegin(), values.cend()); + }, + RawSparseDataChecker); + + TestDenseToSparseConversion( + [](const std::vector& values, TensorProto& tp) { + tp.set_name("dense_int64"); + RawDataWriter(values, tp, TensorProto_DataType_UINT8); + }, + RawSparseDataChecker); +} + +#endif // !ORT_MINIMAL_BUILD + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 307c0b0f49..2866f298fc 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -146,6 +146,73 @@ static void ConstructASimpleAddGraph(GraphProto& g, const char* domain) { SetTypeAndShape(output->mutable_type()->mutable_tensor_type(), 1, {3, 4, 5}); } +namespace sparse_details { +const std::vector shape = {3, 4, 5}; +const std::vector values = {13.f, + 17.f, + 19.f}; + +const std::vector indices = {9, 30, 50}; // Not to exceed 59 +} // namespace sparse_details + +// To match a simple Add graph above +static void ConstructSparseTensor(const std::string& name, + SparseTensorProto& sparse_proto) { + const std::vector& shape = sparse_details::shape; + const std::vector& values = sparse_details::values; + + auto& m_values = *sparse_proto.mutable_values(); + m_values.set_name(name); + m_values.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + *m_values.mutable_dims()->Add() = static_cast(values.size()); + std::string& raw_data = *m_values.mutable_raw_data(); + raw_data.resize(values.size() * sizeof(float)); + auto dest_span = gsl::make_span(reinterpret_cast(&raw_data[0]), values.size()); + std::copy(values.cbegin(), values.cend(), dest_span.begin()); + + const std::vector& indices = sparse_details::indices; // Not to exceed 59 + auto& m_indicies = *sparse_proto.mutable_indices(); + m_indicies.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + *m_indicies.mutable_dims()->Add() = static_cast(indices.size()); + auto* m_indicies_data = m_indicies.mutable_int64_data(); + m_indicies_data->Resize(static_cast(indices.size()), 0); + std::copy(indices.cbegin(), indices.cend(), m_indicies_data->begin()); + + auto& m_dims = *sparse_proto.mutable_dims(); + m_dims.Resize(static_cast(shape.size()), 0); + std::copy(shape.cbegin(), shape.cend(), m_dims.begin()); +} + +static void ValidateSparseTensorProto(const SparseTensorProto& proto) { + // check values. We always generate float + EXPECT_EQ(proto.values().data_type(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + EXPECT_EQ(proto.values().raw_data().size() % sizeof(float), 0U); + auto actual_values = gsl::make_span(reinterpret_cast(proto.values().raw_data().data()), + proto.values().raw_data().size() / sizeof(float)); + // Can't use ContainerEq on float + EXPECT_EQ(actual_values.size(), sparse_details::values.size()); + // std::equal() with a predicate is only in C++20 + auto actual_begin = actual_values.cbegin(); + const auto actual_end = actual_values.cend(); + auto expected_begin = sparse_details::values.cbegin(); + while (actual_begin != actual_end) { + auto diff = *actual_begin - *expected_begin; + EXPECT_TRUE(diff < std::numeric_limits::epsilon()) << "Actual :" << *actual_begin << " does not match expected: " << *expected_begin; + ++actual_begin; + ++expected_begin; + } + // Check indices + EXPECT_EQ(proto.indices().data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); + auto expected_indices = gsl::make_span(sparse_details::indices); + auto actual_indices = gsl::make_span(proto.indices().int64_data().data(), proto.indices().int64_data_size()); + EXPECT_THAT(actual_indices, testing::ContainerEq(expected_indices)); + // check shape + const auto& dims = proto.dims(); + auto actual_shape = gsl::make_span(dims.data(), dims.size()); + auto expected_shape = gsl::make_span(sparse_details::shape); + EXPECT_THAT(actual_shape, testing::ContainerEq(expected_shape)); +} + TEST_F(GraphTest, SimpleAddWithoutDomain) { ModelProto m; m.set_ir_version(3); @@ -1064,6 +1131,37 @@ TEST_F(GraphTest, UnusedInitializerIsIgnored) { ASSERT_TRUE(graph.GetAllInitializedTensors().empty()); } +TEST_F(GraphTest, UnusedSparseInitializerIsIgnored) { + std::string s1; + { + Model model("UnusedSparseInitializerIsIgnored", false, *logger_); + auto model_proto = model.ToProto(); + auto* m_graph = model_proto.mutable_graph(); + ConstructASimpleAddGraph(*m_graph, nullptr); + auto* m_sparse_initializer = m_graph->add_sparse_initializer(); + ConstructSparseTensor("unused_sparse_initializer", *m_sparse_initializer); + model_proto.SerializeToString(&s1); + } + + ModelProto model_proto_1; + const bool result = model_proto_1.ParseFromString(s1); + ASSERT_TRUE(result) << "Failed to load model from serialized protobuf"; + ASSERT_EQ(model_proto_1.graph().initializer_size(), 0); + ASSERT_EQ(model_proto_1.graph().sparse_initializer_size(), 1); + + std::shared_ptr p_tmp_model; + auto x = onnxruntime::Model::Load(model_proto_1, p_tmp_model, nullptr, *logger_); + ASSERT_STATUS_OK(x); + + auto& graph2 = p_tmp_model->MainGraph(); + EXPECT_STATUS_OK(graph2.Resolve()); + // Because the sparse initializer was unused, it was also removed + // from initializer as well as from sparse_initializer + ASSERT_TRUE(graph2.GetAllInitializedTensors().empty()); + auto& graph_proto = graph2.ToGraphProto(); + ASSERT_TRUE(graph_proto.sparse_initializer().empty()); +} + TEST_F(GraphTest, GraphConstruction_CheckIsNotAcyclic) { // A cyclic graph // SouceNode @@ -1527,6 +1625,57 @@ TEST_F(GraphTest, AddRemoveInitializerHandling) { << num_initializers << " remain."; } +TEST_F(GraphTest, SparseInitializerHandling) { + const char* const input_initializer_name = "x"; + Model model("SparseInitializerHandling", false, *logger_); + std::string s1; + // Create model proto with sparse initializer + { + auto model_proto = model.ToProto(); + auto* m_graph = model_proto.mutable_graph(); + ConstructASimpleAddGraph(*m_graph, nullptr); + auto* m_sparse_initializer = m_graph->add_sparse_initializer(); + ConstructSparseTensor(input_initializer_name, *m_sparse_initializer); + model_proto.SerializeToString(&s1); + } + + ModelProto model_proto_sparse; + const bool result = model_proto_sparse.ParseFromString(s1); + ASSERT_TRUE(result) << "Failed to load model from serialized protobuf"; + { + auto& graph_proto = model_proto_sparse.graph(); + ASSERT_EQ(graph_proto.initializer_size(), 0); + ASSERT_EQ(graph_proto.sparse_initializer_size(), 1); + ValidateSparseTensorProto(graph_proto.sparse_initializer().at(0)); + } + + std::shared_ptr p_tmp_model; + auto x = onnxruntime::Model::Load(model_proto_sparse, p_tmp_model, nullptr, *logger_); + + auto& graph2 = p_tmp_model->MainGraph(); + EXPECT_STATUS_OK(graph2.Resolve()); + // Sparse initializer got converted to dense and appears on the list of initializers + ASSERT_EQ(graph2.GetAllInitializedTensors().size(), 1U); + ASSERT_EQ(graph2.GetAllInitializedTensors().cbegin()->first.compare(input_initializer_name), 0); + + auto& graph_proto = graph2.ToGraphProto(); + // Got propagated to initializers list + ASSERT_EQ(graph_proto.initializer_size(), 1); + ASSERT_EQ(graph_proto.initializer().at(0).name().compare(input_initializer_name), 0); + // Got removed from sparse initializer list + ASSERT_EQ(graph_proto.sparse_initializer_size(), 0); + + { + // Check that Model::ToProto() does not return sparse among the normal initializers + // but reconstitutes sparse initializer from dense. Thus, here we have dense initializer list empty + // but it appears to be in the sparse. + auto model_proto_get = p_tmp_model->ToProto(); + ASSERT_EQ(model_proto_get.graph().initializer_size(), 0); + ASSERT_EQ(model_proto_get.graph().sparse_initializer_size(), 1); + ValidateSparseTensorProto(model_proto_get.graph().sparse_initializer().at(0)); + } +} + TEST_F(GraphTest, SetInputsAndSetOutputs_NewInputAndOutput) { std::shared_ptr model; { diff --git a/onnxruntime/test/testdata/sparse_initializer_handling.onnx b/onnxruntime/test/testdata/sparse_initializer_handling.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d3d1c5105d7fbe3fa96e9a58a4ef67c346bc9ef8 GIT binary patch literal 324 zcmZvXziz@X5XNmBL^!3D>$V~k)JjZD9y+qcQ`u0vb#oKbT8bUyguhSFN8k~7%w7hh zA~Af!eZRZ=?&99;L_q-E0o@B$uiX=wu&gM=@MEPzZRXxKR+si*!khL*Y5n&*`5fSY z5Wr(966nMX6`{H<*k1!*Lui*@Lx(9m#*EWRLdYT`L`88!}ly+4F zgePhM_j%={vw5iMYS~Di|7>$k)eDR+fh(6rH|x+|Ut_~9fVQ&oekWbYJL&=Wxo(2C Z9-qUmF{EdA70kw(t^7yt;N~KP(?8_gNf!VB literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/sparse_initializer_handling.onnx.ort b/onnxruntime/test/testdata/sparse_initializer_handling.onnx.ort new file mode 100644 index 0000000000000000000000000000000000000000..fa3b88f45ff8b16d0bc2191b0918f7ae4e9097f0 GIT binary patch literal 1488 zcmZ`(J8KkC6h5n?j;n@6gDfavibYZ^!y-Yp2$DrDEW|+~Y`f0Rnt{xW%s#TFuuf?u zDbi{sVv$mUg@t9XvJwSs!(jd=9 zaW~4DjBROkyIjr*{Y@0xh2|Nyi=yBh-=XlQBfW-M`iM1qGNy+$HFH-151{?b_}@dr zEMfdgUKsyp<7X`8jqx8EKkI4Xe96L5^W|y{d>$EYFiUN$EuUYe&T&2qh+(e)nmLxw zHRInfe#_@3Y~H0SmvINKV!*s9FIZr-FvE@SC7boF6;TbI7SEfJf z@{X4QayXduugvwgyuVELFV5{-{vxN*$HA|DlP=!e|IR@!sqN<*t)F4uY=jP4lQzI4G+*?c+e8r0d`y5 zv-w<;D>XUN-ybIG+h85~xcAg(#*$+V&L7bGz;mDrGyru@%$hk?=(_-fS>h*gm}mX0 zH-K&5A>$ZBzMa_5(sb-U$)iymZTW+|6Q^J z{;Q7Fxq|&UZho@ai#mg09&bl^A9LluBwwPQ1(d3L>7Tqb+}H{c#~ab{diKmPv(Za6 HQ{LxadEmbB literal 0 HcmV?d00001 diff --git a/orttraining/orttraining/models/gpt2/main.cc b/orttraining/orttraining/models/gpt2/main.cc index 70610cf1d3..4c418ce6fa 100644 --- a/orttraining/orttraining/models/gpt2/main.cc +++ b/orttraining/orttraining/models/gpt2/main.cc @@ -239,7 +239,7 @@ Status ParseArguments(int argc, char* argv[], GPT2Parameters& params, OrtParamet int64_t seed = flags["seed"].as(); if (params.horizontal_parallel_size > 1 && seed <= 0) { - seed = 8211; // Megatron needs a random seed. + seed = 8211; // Megatron needs a random seed. } if (seed > 0) { utils::SetRandomSeed(seed); @@ -279,7 +279,7 @@ float GetLossValue(const Tensor& loss_tensor) { // mapping to define what to be stored in mapped_dimensions // see GetTensorDimensionsFromInputs() in training_util.h and training_runner.cc for more details const std::map> input_to_dimension_mapping = { - {"input_ids", {"SeqLen", 0}}, // int64[batch,seqlen] "seqlen" -> "SeqLen", 0 + {"input_ids", {"SeqLen", 0}}, // int64[batch,seqlen] "seqlen" -> "SeqLen", 0 }; // generic properties for storing perf metrics diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 12d2a97650..291dad6acd 100755 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1101,8 +1101,8 @@ def run_android_tests(args, source_dir, config, cwd): # For Android arm64 abi we are only verify the size of the binary generated by minimal build config # Will fail the build if the shared_lib size is larger than the threshold if args.minimal_build and config == 'MinSizeRel' and args.build_shared_lib and args.test_binary_size: - # set current size limit to 1100KB - bin_size_threshold = 1100000 + # set current size limit to 1165KB which is 110K large than 1.5.2 release. + bin_size_threshold = 1165000 bin_actual_size = os.path.getsize(os.path.join(cwd, 'libonnxruntime.so')) log.info('Android arm64 minsizerel libonnxruntime.so size [' + str(bin_actual_size) + 'B]') # Write the binary size to a file for uploading later