onnxruntime/onnxruntime/core/framework/tensorprotoutils.h
Dmitri Smirnov 3433576fd3
Support for Sparse Initializers (#5540)
Introduce sparse_initializers support.
  Convert them to dense on model load and prune graph_proto_
  so they don't consume space. Convert back to sparse on ORT Format model save.
  Implement serializing sparse initializers to OrtFormat.
  Fix Model::ToProto() to return original sparse initializers
  Set a flag indicating that graph_sync is needed when loading a simple ORT Format model;
  otherwise nothing is resolved.
  Add ORT Format history to README.md
  ifdef MINIMAL build for DenseToSparseTensorInitializer
  Allow duplicate initializers to support existing models.
  Issue a warning instead of aborting.

* Revert "Remove SparseTensor support from minimal build. (#5114)"
This reverts commit 59ee8ffb17.



Signed-off-by: Dmitri Smirnov <dmitrism@microsoft.com>
2020-10-27 10:32:06 -07:00

243 lines
9.8 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <vector>
#include <type_traits>
#include "core/common/common.h"
#include "core/common/status.h"
#include "core/framework/allocator.h"
#include "core/framework/ml_value.h"
#include "core/framework/mem_buffer.h"
#include "core/framework/tensor_external_data_info.h"
#include "core/graph/onnx_protobuf.h"
#include "core/platform/env.h"
namespace ONNX_NAMESPACE {
// Forward declarations of the ONNX protobuf message types referenced by this header.
class TensorProto;
class TensorShapeProto;

/** Returns true when the two TensorShapeProto dimensions are considered equal.
 * Only declared here; the definition lives outside this header. */
bool operator==(const TensorShapeProto_Dimension& lhs, const TensorShapeProto_Dimension& rhs);

/** Returns true when the two TensorShapeProto dimensions are considered different. */
bool operator!=(const TensorShapeProto_Dimension& lhs, const TensorShapeProto_Dimension& rhs);
}  // namespace ONNX_NAMESPACE
namespace onnxruntime {
class Tensor;
namespace utils {
/** Converts a TensorShapeProto into a TensorShape.
 * @param tensor_shape_proto the shape to convert.
 * NOTE(review): how symbolic (dim_param) or unset dimensions are represented in the result is decided by the
 * implementation, which is not visible here — confirm before relying on it.
 */
TensorShape GetTensorShapeFromTensorShapeProto(const ONNX_NAMESPACE::TensorShapeProto& tensor_shape_proto);

/**
 * Deserializes a TensorProto into a preallocated memory buffer.
 * @param env environment handle used by the implementation (presumably for reading external data — confirm).
 * @param tensor_proto_path A local file path of where the 'input' was loaded from. Can be NULL if the tensor proto
 *        doesn't have any external data or it was loaded from the current working dir. This path could be either a
 *        relative path or an absolute path.
 * @param input the TensorProto to deserialize.
 * @param m the preallocated buffer that receives the tensor data.
 * @param value [out] the OrtValue created from `input`.
 * @param deleter [out] callback filled in by the implementation. NOTE(review): presumably the caller must invoke it
 *        to release resources tied to `value` — confirm.
 * @return Status::OK() on success, an error Status otherwise.
 */
common::Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* tensor_proto_path,
                                    const ONNX_NAMESPACE::TensorProto& input, const MemBuffer& m, OrtValue& value,
                                    OrtCallback& deleter);

/** Creates a TensorProto from a Tensor.
 * @param[in] tensor the Tensor whose data and shape will be used to create the TensorProto.
 * @param[in] tensor_proto_name the name of the TensorProto.
 * @return the TensorProto.
 * Note: Method currently requires that data is in little-endian format.
 */
ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name);

/** Maps a protobuf tensor element type (presumably a TensorProto::DataType value — confirm) to the corresponding
 * C API ONNXTensorElementDataType. */
ONNXTensorElementDataType CApiElementTypeFromProtoType(int type);

/** Returns the C API element type of the data held by `tensor_proto`. */
ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto& tensor_proto);

/** Computes how much memory is needed to put the content of `tensor_proto` into a plain array.
 * complex64/complex128 tensors are not supported.
 * @tparam alignment alignment parameter of the size computation (presumably the result is rounded up to it — confirm).
 * @param out [out] receives the size in bytes. The output value could be zero or -1.
 *        NOTE(review): `out` is a size_t, so "-1" presumably means SIZE_MAX — confirm against the implementation.
 */
template <size_t alignment>
common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out);

/** Unpacks the data of `tensor` into `p_data`.
 * @param tensor the source TensorProto.
 * @param raw_data pointer to the raw bytes to unpack, or nullptr to read the type-specific data field of `tensor`
 *        instead (the two-argument convenience overload of UnpackTensor in this header passes nullptr, 0 when the
 *        tensor has no raw data).
 * @param raw_data_len number of bytes available at `raw_data`.
 * @param p_data [out] destination buffer.
 * @param expected_size expected amount of data (presumably an element count of T, not bytes — confirm).
 */
template <typename T>
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len,
                    /*out*/ T* p_data, size_t expected_size);

// Convert the NodeProto from a Constant node into a TensorProto that can be used as an initializer.
common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node,
                                              ONNX_NAMESPACE::TensorProto& tensor);

// Convert a SparseTensorProto to a dense TensorProto.
common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseTensorProto& sparse,
                                                   ONNX_NAMESPACE::TensorProto& dense);

#if !defined(ORT_MINIMAL_BUILD)
// Convert a dense TensorProto to a SparseTensorProto. Not declared in minimal builds.
common::Status DenseTensorToSparseTensorProto(const ONNX_NAMESPACE::TensorProto& dense,
                                              ONNX_NAMESPACE::SparseTensorProto& sparse);
#endif  // !ORT_MINIMAL_BUILD
/// True when the dimension's value_case() is kDimValue (a concrete value rather than a parameter).
inline bool HasDimValue(const ONNX_NAMESPACE::TensorShapeProto_Dimension& dim) {
  const auto which = dim.value_case();
  return which == ONNX_NAMESPACE::TensorShapeProto_Dimension::kDimValue;
}

/// True when the dimension's value_case() is kDimParam (a named/symbolic dimension).
inline bool HasDimParam(const ONNX_NAMESPACE::TensorShapeProto_Dimension& dim) {
  const auto which = dim.value_case();
  return which == ONNX_NAMESPACE::TensorShapeProto_Dimension::kDimParam;
}

/// True when the TypeProto's value_case() is kTensorType.
inline bool HasTensorType(const ONNX_NAMESPACE::TypeProto& type_proto) {
  const auto which = type_proto.value_case();
  return which == ONNX_NAMESPACE::TypeProto::kTensorType;
}
/// True when the tensor type declares an element type, i.e. elem_type() is not UNDEFINED.
inline bool HasElemType(const ONNX_NAMESPACE::TypeProto_Tensor& ten_proto) {
  return ten_proto.elem_type() != ONNX_NAMESPACE::TensorProto::UNDEFINED;
}

/// True when the tensor type has its shape field set.
inline bool HasShape(const ONNX_NAMESPACE::TypeProto_Tensor& ten_proto) {
  // XXX: Figure out how to do this in proto3
  return ten_proto.has_shape();
}

/// True when the sparse tensor type has its shape field set.
inline bool HasShape(const ONNX_NAMESPACE::TypeProto_SparseTensor& ten_proto) {
  // XXX: Figure out how to do this in proto3
  return ten_proto.has_shape();
}
/// True when the tensor has a defined data type and its raw_data field is set.
inline bool HasRawData(const ONNX_NAMESPACE::TensorProto& ten_proto) {
  // An UNDEFINED data type never counts as having raw data. STRING cannot use raw data either, but
  // callers usually test for STRING separately so they can report an error.
  if (ten_proto.data_type() == ONNX_NAMESPACE::TensorProto::UNDEFINED) {
    return false;
  }
  return ten_proto.has_raw_data();  // XXX: Figure out how to do in proto3
}

/// True when data_type() is anything other than UNDEFINED.
inline bool HasDataType(const ONNX_NAMESPACE::TensorProto& ten_proto) {
  const auto dtype = ten_proto.data_type();
  return dtype != ONNX_NAMESPACE::TensorProto::UNDEFINED;
}

/// True when the tensor's name field is set.
inline bool HasName(const ONNX_NAMESPACE::TensorProto& ten_proto) {
  const bool named = ten_proto.has_name();  // XXX
  return named;
}
/// True when the sequence's element TypeProto has some value set (value_case() is not VALUE_NOT_SET).
inline bool HasElemType(const ONNX_NAMESPACE::TypeProto_Sequence& seq_proto) {
  const auto& elem = seq_proto.elem_type();
  return elem.value_case() != ONNX_NAMESPACE::TypeProto::VALUE_NOT_SET;
}

/// True when the sparse tensor type declares an element type (elem_type() is not UNDEFINED).
inline bool HasElemType(const ONNX_NAMESPACE::TypeProto_SparseTensor& ten_proto) {
  const auto elem = ten_proto.elem_type();
  return elem != ONNX_NAMESPACE::TensorProto::UNDEFINED;
}

/// A sparse tensor is considered named when its `values` tensor has a name set.
inline bool HasName(const ONNX_NAMESPACE::SparseTensorProto& ten_proto) {
  const auto& values = ten_proto.values();
  return values.has_name();  // XXX
}
/// True when the map type declares a key type (key_type() is not UNDEFINED).
inline bool HasKeyType(const ONNX_NAMESPACE::TypeProto_Map& map_proto) {
  const auto key = map_proto.key_type();
  return key != ONNX_NAMESPACE::TensorProto::UNDEFINED;
}

/// True when the map's value TypeProto has some value set.
inline bool HasValueType(const ONNX_NAMESPACE::TypeProto_Map& map_proto) {
  const auto& value = map_proto.value_type();
  return value.value_case() != ONNX_NAMESPACE::TypeProto::VALUE_NOT_SET;
}

/// True when the value info's TypeProto has some value set.
inline bool HasType(const ONNX_NAMESPACE::ValueInfoProto& vi_proto) {
  const auto& type = vi_proto.type();
  return type.value_case() != ONNX_NAMESPACE::TypeProto::VALUE_NOT_SET;
}

/// True when the value info's name field is set.
inline bool HasName(const ONNX_NAMESPACE::ValueInfoProto& vi_proto) {
  const bool named = vi_proto.has_name();  // XXX: Figure out proto3 way
  return named;
}

/// An opaque type has a domain when its domain string is non-empty.
inline bool HasDomain(const ONNX_NAMESPACE::TypeProto_Opaque& op_proto) {
  const auto& domain = op_proto.domain();
  return domain.size() != 0;
}

/// An opaque type has a name when its name string is non-empty.
inline bool HasName(const ONNX_NAMESPACE::TypeProto_Opaque& op_proto) {
  const auto& name = op_proto.name();
  return name.size() != 0;
}
// Predicates over AttributeProto. HasType() reports whether any type is declared. Each Has<Kind>() reports
// whether the declared type() is exactly that kind, e.g. HasFloat() is true only for FLOAT, not FLOATS.
// The enumerators are written as AttributeProto::X, matching the TensorProto::UNDEFINED style used in this header.
inline bool HasType(const ONNX_NAMESPACE::AttributeProto& at_proto) {
  return at_proto.type() != ONNX_NAMESPACE::AttributeProto::UNDEFINED;
}
inline bool HasFloat(const ONNX_NAMESPACE::AttributeProto& at_proto) {
  return at_proto.type() == ONNX_NAMESPACE::AttributeProto::FLOAT;
}
inline bool HasFloats(const ONNX_NAMESPACE::AttributeProto& at_proto) {
  return at_proto.type() == ONNX_NAMESPACE::AttributeProto::FLOATS;
}
inline bool HasInt(const ONNX_NAMESPACE::AttributeProto& at_proto) {
  return at_proto.type() == ONNX_NAMESPACE::AttributeProto::INT;
}
inline bool HasInts(const ONNX_NAMESPACE::AttributeProto& at_proto) {
  return at_proto.type() == ONNX_NAMESPACE::AttributeProto::INTS;
}
inline bool HasString(const ONNX_NAMESPACE::AttributeProto& at_proto) {
  return at_proto.type() == ONNX_NAMESPACE::AttributeProto::STRING;
}
inline bool HasStrings(const ONNX_NAMESPACE::AttributeProto& at_proto) {
  return at_proto.type() == ONNX_NAMESPACE::AttributeProto::STRINGS;
}
inline bool HasTensor(const ONNX_NAMESPACE::AttributeProto& at_proto) {
  return at_proto.type() == ONNX_NAMESPACE::AttributeProto::TENSOR;
}
inline bool HasTensors(const ONNX_NAMESPACE::AttributeProto& at_proto) {
  return at_proto.type() == ONNX_NAMESPACE::AttributeProto::TENSORS;
}
inline bool HasGraph(const ONNX_NAMESPACE::AttributeProto& at_proto) {
  return at_proto.type() == ONNX_NAMESPACE::AttributeProto::GRAPH;
}
inline bool HasGraphs(const ONNX_NAMESPACE::AttributeProto& at_proto) {
  return at_proto.type() == ONNX_NAMESPACE::AttributeProto::GRAPHS;
}
/// True when the attribute's name field is set.
inline bool HasName(const ONNX_NAMESPACE::AttributeProto& at_proto) {
  return at_proto.has_name();  // XXX: Figure out proto3
}
/// True when the model's graph field is set.
inline bool HasGraph(const ONNX_NAMESPACE::ModelProto& m_proto) {
  const bool present = m_proto.has_graph();  // XXX proto3
  return present;
}

/// True when the model's ir_version field is set.
inline bool HasIrVersion(const ONNX_NAMESPACE::ModelProto& m_proto) {
  const bool present = m_proto.has_ir_version();  // XXX proto3
  return present;
}

/// True when the model's model_version field is set.
inline bool HasModelVersion(const ONNX_NAMESPACE::ModelProto& m_proto) {
  const bool present = m_proto.has_model_version();  // XXX proto3
  return present;
}

/// True when the node's name field is set.
inline bool HasName(const ONNX_NAMESPACE::NodeProto& node_proto) {
  // XXX: Figure out proto3 style
  const bool present = node_proto.has_name();
  return present;
}
/// Unpacks `tensor` into `p_data`. Raw bytes are used when the tensor has them (HasRawData); otherwise the
/// full overload is called with (nullptr, 0) so it reads the type-specific data field.
template <typename T>
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, /*out*/ T* p_data, size_t expected_size) {
  if (!HasRawData(tensor)) {
    return UnpackTensor(tensor, nullptr, 0, p_data, expected_size);
  }
  const auto& raw = tensor.raw_data();
  return UnpackTensor(tensor, raw.data(), raw.size(), p_data, expected_size);
}
/**
 * Unpack the data from an initializer tensor.
 * Please note, this function does not unpack string_data of an initializer tensor.
 * @param initializer given initializer tensor
 * @param unpacked_tensor [out] the data from the initializer, as a uint8_t array
 * @param tensor_byte_size [out] the byte size of the unpacked_tensor
 * @returns Status::OK() if data is unpacked successfully
 */
common::Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer,
                                     std::unique_ptr<uint8_t[]>& unpacked_tensor,
                                     size_t& tensor_byte_size) ORT_MUST_USE_RESULT;
} // namespace utils
} // namespace onnxruntime