mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
Constant-12 support (#3304)
1. Support the new fields for Constant in opset 12 2. Support SparseTensor in the Constant node by converting to dense tensor when lifting the Constant to an initializer. Will make a model with a sparse tensor in a Constant work but isn't an overly efficient approach.
This commit is contained in:
parent
2332a93db0
commit
ace741680d
5 changed files with 570 additions and 54 deletions
|
|
@ -59,28 +59,28 @@ namespace onnxruntime {
|
|||
namespace utils {
|
||||
|
||||
// This macro doesn't work for Float16/bool/string tensors
|
||||
#define DEFINE_UNPACK_TENSOR(T, Type, field_name, field_size) \
|
||||
template <> \
|
||||
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, \
|
||||
/*out*/ T* p_data, size_t expected_size) { \
|
||||
if (nullptr == p_data) { \
|
||||
const size_t size = raw_data != nullptr ? raw_data_len : tensor.field_size(); \
|
||||
if (size == 0) return Status::OK(); \
|
||||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \
|
||||
} \
|
||||
if (nullptr == p_data || Type != tensor.data_type()) { \
|
||||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \
|
||||
} \
|
||||
if (raw_data != nullptr) { \
|
||||
return UnpackTensorWithRawData(raw_data, raw_data_len, expected_size, p_data); \
|
||||
} \
|
||||
if (static_cast<size_t>(tensor.field_size()) != expected_size) \
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "corrupted protobuf data: tensor shape size(", expected_size, \
|
||||
") does not match the data size(", tensor.field_size(), ") in proto"); \
|
||||
auto& data = tensor.field_name(); \
|
||||
for (auto data_iter = data.cbegin(); data_iter != data.cend(); ++data_iter) \
|
||||
*p_data++ = *reinterpret_cast<const T*>(data_iter); \
|
||||
return Status::OK(); \
|
||||
#define DEFINE_UNPACK_TENSOR(T, Type, field_name, field_size) \
|
||||
template <> \
|
||||
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, \
|
||||
/*out*/ T* p_data, size_t expected_size) { \
|
||||
if (nullptr == p_data) { \
|
||||
const size_t size = raw_data != nullptr ? raw_data_len : tensor.field_size(); \
|
||||
if (size == 0) return Status::OK(); \
|
||||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \
|
||||
} \
|
||||
if (nullptr == p_data || Type != tensor.data_type()) { \
|
||||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \
|
||||
} \
|
||||
if (raw_data != nullptr) { \
|
||||
return UnpackTensorWithRawData(raw_data, raw_data_len, expected_size, p_data); \
|
||||
} \
|
||||
if (static_cast<size_t>(tensor.field_size()) != expected_size) \
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "corrupted protobuf data: tensor shape size(", \
|
||||
expected_size, ") does not match the data size(", tensor.field_size(), ") in proto"); \
|
||||
auto& data = tensor.field_name(); \
|
||||
for (auto data_iter = data.cbegin(); data_iter != data.cend(); ++data_iter) \
|
||||
*p_data++ = *reinterpret_cast<const T*>(data_iter); \
|
||||
return Status::OK(); \
|
||||
}
|
||||
|
||||
// TODO: complex64 complex128
|
||||
|
|
@ -310,10 +310,10 @@ ORT_API(void, OrtUninitializeBuffer, _In_opt_ void* input, size_t input_len, enu
|
|||
}
|
||||
}
|
||||
|
||||
#define CASE_PROTO(X, Y) \
|
||||
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \
|
||||
ORT_RETURN_IF_ERROR( \
|
||||
::onnxruntime::utils::UnpackTensor<Y>(tensor_proto, raw_data, raw_data_len, (Y*)preallocated, static_cast<size_t>(tensor_size))); \
|
||||
#define CASE_PROTO(X, Y) \
|
||||
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \
|
||||
ORT_RETURN_IF_ERROR( \
|
||||
UnpackTensor<Y>(tensor_proto, raw_data, raw_data_len, (Y*)preallocated, static_cast<size_t>(tensor_size))); \
|
||||
break;
|
||||
|
||||
class AutoDelete {
|
||||
|
|
@ -421,7 +421,8 @@ Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* tensor_proto_path,
|
|||
int64_t tensor_size = 1;
|
||||
{
|
||||
for (auto i : tensor_proto.dims()) {
|
||||
if (i < 0) return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "tensor can't contain negative dims");
|
||||
if (i < 0)
|
||||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "tensor can't contain negative dims");
|
||||
tensor_size *= i;
|
||||
}
|
||||
}
|
||||
|
|
@ -464,8 +465,8 @@ Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* tensor_proto_path,
|
|||
deleter.f = UnInitTensor;
|
||||
deleter.param = new UnInitializeParam{preallocated, preallocated_size, ele_type};
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(::onnxruntime::utils::UnpackTensor<std::string>(tensor_proto, raw_data, raw_data_len,
|
||||
(std::string*)preallocated, static_cast<size_t>(tensor_size)));
|
||||
ORT_RETURN_IF_ERROR(UnpackTensor<std::string>(tensor_proto, raw_data, raw_data_len,
|
||||
(std::string*)preallocated, static_cast<size_t>(tensor_size)));
|
||||
break;
|
||||
default: {
|
||||
std::ostringstream ostr;
|
||||
|
|
@ -536,12 +537,238 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std:
|
|||
//tensor_proto.set_data_type(utils::GetTensorProtoType(tensor));
|
||||
|
||||
tensor_proto.set_data_type(tensor_proto_type.tensor_type().elem_type());
|
||||
|
||||
tensor_proto.set_raw_data(tensor.DataRaw(), tensor.SizeInBytes());
|
||||
|
||||
return tensor_proto;
|
||||
}
|
||||
|
||||
// Convert the value carried by a 'Constant' node into a TensorProto that can be used as an initializer.
//
// The first attribute of `node` selects the conversion:
//   - TENSOR:                  the attribute's tensor is copied as-is
//   - FLOAT / INT / STRING:    a single value is stored in the matching typed field (no dims, i.e. a scalar)
//   - FLOATS / INTS / STRINGS: the values are copied and a single dim holding the element count is added so
//                              the result is a 1-D tensor
//   - SPARSE_TENSOR:           converted with SparseTensorProtoToDenseTensorProto
//
// `tensor` is the output and is named after the node's first output.
// Returns an error Status if the node has no attribute or no output, or if the sparse conversion fails.
// Throws (ORT_THROW) if the attribute type is not one of the above.
common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node,
                                              ONNX_NAMESPACE::TensorProto& tensor) {
  // attribute(0) and output(0) are read below, so make sure they exist instead of indexing blindly
  if (node.attribute_size() < 1 || node.output_size() < 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "'Constant' node '", node.name(),
                           "' must have an attribute providing the value and an output");
  }

  const AttributeProto& constant_attribute = node.attribute(0);

  switch (constant_attribute.type()) {
    case AttributeProto_AttributeType_TENSOR:
      tensor = constant_attribute.t();
      break;
    case AttributeProto_AttributeType_FLOAT:
      tensor.set_data_type(TensorProto_DataType_FLOAT);
      tensor.add_float_data(constant_attribute.f());
      break;
    case AttributeProto_AttributeType_FLOATS:
      tensor.set_data_type(TensorProto_DataType_FLOAT);
      *tensor.mutable_float_data() = constant_attribute.floats();
      // record the 1-D shape, otherwise the tensor looks like a scalar holding several values
      tensor.add_dims(constant_attribute.floats().size());
      break;
    case AttributeProto_AttributeType_INT:
      tensor.set_data_type(TensorProto_DataType_INT64);
      tensor.add_int64_data(constant_attribute.i());
      break;
    case AttributeProto_AttributeType_INTS:
      tensor.set_data_type(TensorProto_DataType_INT64);
      *tensor.mutable_int64_data() = constant_attribute.ints();
      tensor.add_dims(constant_attribute.ints().size());
      break;
    case AttributeProto_AttributeType_STRING:
      tensor.set_data_type(TensorProto_DataType_STRING);
      tensor.add_string_data(constant_attribute.s());
      break;
    case AttributeProto_AttributeType_STRINGS: {
      tensor.set_data_type(TensorProto_DataType_STRING);
      *tensor.mutable_string_data() = constant_attribute.strings();
      tensor.add_dims(constant_attribute.strings().size());
      break;
    }
    case AttributeProto_AttributeType_SPARSE_TENSOR: {
      auto& s = constant_attribute.sparse_tensor();
      ORT_RETURN_IF_ERROR(SparseTensorProtoToDenseTensorProto(s, tensor));
      break;
    }
    default:
      ORT_THROW("Unsupported attribute value type of ", constant_attribute.type(),
                " in 'Constant' node '", node.name(), "'");
  }

  // set name last in case attribute type was tensor (would copy over name)
  *(tensor.mutable_name()) = node.output(0);

  return Status::OK();
}
|
||||
|
||||
template <typename T>
|
||||
static Status CopySparseData(size_t n_sparse_elements,
|
||||
const ONNX_NAMESPACE::TensorProto& indices,
|
||||
gsl::span<const int64_t> dims,
|
||||
std::function<void(size_t from_idx, size_t to_idx)> copier) {
|
||||
Status status = Status::OK();
|
||||
TensorShape indices_shape(indices.dims().data(), indices.dims().size());
|
||||
|
||||
auto indices_data = gsl::make_span<const int64_t>(indices.int64_data().data(), static_cast<size_t>(indices_shape.Size()));
|
||||
|
||||
if (indices_shape.NumDimensions() == 1) {
|
||||
// flattened indexes
|
||||
for (size_t i = 0; i < n_sparse_elements; ++i) {
|
||||
copier(i, static_cast<size_t>(indices_data[i]));
|
||||
}
|
||||
} else if (indices_shape.NumDimensions() == 2) {
|
||||
// entries in format {NNZ, rank}
|
||||
size_t rank = static_cast<size_t>(indices_shape[1]);
|
||||
ORT_ENFORCE(rank == dims.size() && rank > 0);
|
||||
const int64_t* cur_index = indices_data.data();
|
||||
std::vector<size_t> multipliers;
|
||||
multipliers.resize(rank);
|
||||
|
||||
// calculate sum of inner dimension elements for each dimension.
|
||||
// e.g. if shape {2,3,4}, the result should be {3*4, 4, 1}
|
||||
multipliers[rank - 1] = 1;
|
||||
for (int32_t r = static_cast<int32_t>(rank) - 2; r >= 0; --r) {
|
||||
multipliers[r] = static_cast<size_t>(dims[r + 1]) * multipliers[r + 1];
|
||||
}
|
||||
|
||||
// calculate the offset for the entry
|
||||
// e.g. if shape was {2,3,4} and entry was (1, 0, 2) the offset is 14
|
||||
// as there are 2 rows, each with 12 entries per row
|
||||
for (size_t i = 0; i < n_sparse_elements; ++i) {
|
||||
size_t idx = 0;
|
||||
for (size_t j = 0; j < rank; ++j) {
|
||||
idx += static_cast<size_t>(cur_index[j]) * multipliers[j];
|
||||
}
|
||||
|
||||
copier(i, idx);
|
||||
cur_index += rank;
|
||||
}
|
||||
|
||||
ORT_ENFORCE(cur_index == &*indices_data.cend());
|
||||
} else {
|
||||
status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Invalid SparseTensor indices. Should be rank 0 or 1. Got:",
|
||||
indices_shape);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
// Convert a SparseTensorProto into a dense TensorProto.
//
// `dense` receives the data type of the sparse values and the dims of the sparse tensor. Entries that are not
// present in the sparse tensor are zero (empty string for string tensors).
// Numeric results are written to the raw_data field; strings are written to string_data.
//
// Supported value types: FLOAT, DOUBLE, INT32, INT64, UINT32, UINT64 and STRING.
// Returns an error Status if the type is not supported, if the sparse values cannot be unpacked, or if the
// indices are invalid.
common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseTensorProto& sparse,
                                                   ONNX_NAMESPACE::TensorProto& dense) {
  const auto& sparse_values = sparse.values();
  auto type = sparse_values.data_type();
  dense.set_data_type(type);

  SafeInt<size_t> n_sparse_elements = 1;
  for (auto dim : sparse_values.dims()) {
    n_sparse_elements *= dim;
  }

  SafeInt<size_t> n_dense_elements = 1;
  for (auto dim : sparse.dims()) {
    n_dense_elements *= dim;
    dense.add_dims(dim);
  }

  const auto& indices = sparse.indices();
  auto dims = gsl::make_span<const int64_t>(dense.dims().data(), dense.dims().size());

  // need to read in sparse data first as it could be in a type specific field or in raw data
  size_t sparse_bytes;
  ORT_RETURN_IF_ERROR(GetSizeInBytesFromTensorProto<0>(sparse_values, &sparse_bytes));

  if (type == TensorProto_DataType_STRING) {
    // strings need to be handled differently as they can't use raw data (as per ONNX rules)
    std::vector<std::string> sparse_data(n_sparse_elements);
    ORT_RETURN_IF_ERROR(UnpackTensor<std::string>(sparse_values, sparse_data.data(), n_sparse_elements));

    // RepeatedPtrField<std::string> doesn't have a Resize method so manually add elements
    auto dense_strings = dense.mutable_string_data();
    dense_strings->Reserve(n_dense_elements);
    for (int64_t j = 0; j < n_dense_elements; ++j) {
      dense_strings->Add("");
    }

    // each sparse value is visited exactly once, so it can be moved out of the unpacked buffer
    return CopySparseData<std::string>(
        n_sparse_elements,
        indices, dims,
        [&sparse_data, dense_strings](size_t from_idx, size_t to_idx) {
          *dense_strings->Mutable(SafeInt<int32_t>(to_idx)) = std::move(sparse_data[from_idx]);
        });
  }

  std::vector<unsigned char> sparse_data_storage(sparse_bytes, 0);
  void* sparse_data = sparse_data_storage.data();

  size_t element_size = 0;

  // unpack the sparse values into the temporary buffer and remember how wide each element is.
  // a failure to unpack is reported to the caller instead of converting a buffer of zeros
  switch (type) {
    case TensorProto_DataType_FLOAT: {
      element_size = sizeof(float);
      ORT_RETURN_IF_ERROR(UnpackTensor<float>(sparse_values, static_cast<float*>(sparse_data), n_sparse_elements));
      break;
    }
    case TensorProto_DataType_INT64: {
      element_size = sizeof(int64_t);
      ORT_RETURN_IF_ERROR(
          UnpackTensor<int64_t>(sparse_values, static_cast<int64_t*>(sparse_data), n_sparse_elements));
      break;
    }
    case TensorProto_DataType_INT32: {
      element_size = sizeof(int32_t);
      ORT_RETURN_IF_ERROR(
          UnpackTensor<int32_t>(sparse_values, static_cast<int32_t*>(sparse_data), n_sparse_elements));
      break;
    }
    case TensorProto_DataType_DOUBLE: {
      element_size = sizeof(double);
      ORT_RETURN_IF_ERROR(UnpackTensor<double>(sparse_values, static_cast<double*>(sparse_data), n_sparse_elements));
      break;
    }
    case TensorProto_DataType_UINT32: {
      element_size = sizeof(uint32_t);
      ORT_RETURN_IF_ERROR(
          UnpackTensor<uint32_t>(sparse_values, static_cast<uint32_t*>(sparse_data), n_sparse_elements));
      break;
    }
    case TensorProto_DataType_UINT64: {
      element_size = sizeof(uint64_t);
      ORT_RETURN_IF_ERROR(
          UnpackTensor<uint64_t>(sparse_values, static_cast<uint64_t*>(sparse_data), n_sparse_elements));
      break;
    }
    default:
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported sparse tensor data type of ", type);
  }

  // by putting the data into a std::string we can avoid a copy as set_raw_data can do a std::move
  // into the TensorProto. however to actually write to the buffer we have created in the std::string we need
  // this somewhat dirty hack to get a mutable pointer. we could alternatively use &dense_data_storage.front()
  // but using const_cast makes it more obvious we're doing something ugly.
  std::string dense_data_storage(n_dense_elements * element_size, 0);
  void* dense_data = const_cast<char*>(dense_data_storage.data());

  // the copy only depends on the width of an element, so all the 4 byte types share one path
  // and all the 8 byte types share the other
  Status status = Status::OK();
  switch (element_size) {
    case 4: {
      auto dense_data_span = gsl::make_span<uint32_t>(static_cast<uint32_t*>(dense_data), n_dense_elements);
      status = CopySparseData<uint32_t>(
          n_sparse_elements,
          indices, dims,
          [sparse_data, dense_data_span](size_t from_idx, size_t to_idx) {
            dense_data_span[to_idx] = static_cast<const uint32_t*>(sparse_data)[from_idx];
          });

      break;
    }
    case 8: {
      auto dense_data_span = gsl::make_span<uint64_t>(static_cast<uint64_t*>(dense_data), n_dense_elements);
      status = CopySparseData<uint64_t>(
          n_sparse_elements,
          indices, dims,
          [sparse_data, dense_data_span](size_t from_idx, size_t to_idx) {
            dense_data_span[to_idx] = static_cast<const uint64_t*>(sparse_data)[from_idx];
          });

      break;
    }
  }

  // don't publish a partially filled buffer if the indices were rejected
  ORT_RETURN_IF_ERROR(status);

  dense.set_raw_data(std::move(dense_data_storage));

  return Status::OK();
}
|
||||
|
||||
template common::Status GetSizeInBytesFromTensorProto<256>(const ONNX_NAMESPACE::TensorProto& tensor_proto,
|
||||
size_t* out);
|
||||
template common::Status GetSizeInBytesFromTensorProto<0>(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out);
|
||||
|
|
|
|||
|
|
@ -60,6 +60,14 @@ template <typename T>
|
|||
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len,
|
||||
/*out*/ T* p_data, size_t expected_size);
|
||||
|
||||
// Convert the NodeProto from a Constant node into a TensorProto that can be used as an initializer
|
||||
common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node,
|
||||
ONNX_NAMESPACE::TensorProto& tensor);
|
||||
|
||||
// Convert a SparseTensorProto to a dense TensorProto
|
||||
common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseTensorProto& sparse,
|
||||
ONNX_NAMESPACE::TensorProto& dense);
|
||||
|
||||
// True when the value_case() of the dimension is kDimValue.
inline bool HasDimValue(const ONNX_NAMESPACE::TensorShapeProto_Dimension& dim) {
  const auto which_value = dim.value_case();
  return ONNX_NAMESPACE::TensorShapeProto_Dimension::kDimValue == which_value;
}
|
||||
|
|
@ -199,5 +207,13 @@ inline bool HasName(const ONNX_NAMESPACE::NodeProto& node_proto) {
|
|||
return node_proto.has_name();
|
||||
}
|
||||
|
||||
// UnpackTensor from either raw data or the type specific data field.
|
||||
template <typename T>
|
||||
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, /*out*/ T* p_data, size_t expected_size) {
|
||||
return HasRawData(tensor)
|
||||
? UnpackTensor(tensor, tensor.raw_data().data(), tensor.raw_data().size(), p_data, expected_size)
|
||||
: UnpackTensor(tensor, nullptr, 0, p_data, expected_size);
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -703,7 +703,7 @@ Graph::Graph(const Model& owning_model,
|
|||
const logging::Logger& logger,
|
||||
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_functions)
|
||||
: Graph(owning_model, graph_proto, domain_to_version, ir_version, schema_registry, nullptr, nullptr, logger,
|
||||
model_functions) {}
|
||||
model_functions) {}
|
||||
|
||||
Graph::Graph(const Model& owning_model,
|
||||
GraphProto* graph_proto, const std::unordered_map<std::string, int>& domain_to_version, Version ir_version,
|
||||
|
|
@ -731,16 +731,9 @@ Graph::Graph(const Model& owning_model,
|
|||
continue;
|
||||
}
|
||||
|
||||
// Copy constant nodes _value to name_to_initial_tensor_
|
||||
const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
|
||||
const AttributeProto& constant_attribute = node.attribute(0);
|
||||
// TODO: Add support for parsing 'sparse_value' attribute from a 'Constant' node
|
||||
// Discussion surrounding handling the SparseTensorProto must be had.
|
||||
// An easy way is to implement a method that converts a SparseTensorproto into a TensorProto
|
||||
// to use the same downstream flow, but that is going to impact peak memory usage and probably a smarter way is required.
|
||||
ORT_ENFORCE(constant_attribute.has_t(), "Only 'value' attribute is supported within a 'Constant' node in ORT");
|
||||
*tensor = constant_attribute.t();
|
||||
*(tensor->mutable_name()) = node.output(0);
|
||||
auto status = utils::ConstantNodeProtoToTensorProto(node, *tensor);
|
||||
ORT_ENFORCE(status.IsOK(), status.ToString());
|
||||
}
|
||||
|
||||
// Remove constant nodes as they're replaced with initializers above.
|
||||
|
|
|
|||
|
|
@ -8,13 +8,16 @@
|
|||
#include "core/graph/constants.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/sparse_tensor.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
|
||||
#include "core/graph/model.h"
|
||||
#include "core/session/inference_session.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test_utils.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
|
|
@ -490,5 +493,179 @@ TEST_F(SparseTensorTests, Test2) {
|
|||
RunTest();
|
||||
}
|
||||
|
||||
// Values used as the non-zero entries of the sparse tensor: 1, 2, 3, 4 converted to T.
template <typename T>
static std::vector<T> CreateValues() {
  std::vector<T> values;
  values.reserve(4);
  for (int v = 1; v <= 4; ++v) {
    values.push_back(static_cast<T>(v));
  }

  return values;
}

// String flavour of the values above.
template <>
std::vector<std::string> CreateValues<std::string>() {
  std::vector<std::string> values{"one", "two", "three", "four"};
  return values;
}
|
||||
|
||||
template <typename T>
|
||||
static NodeProto CreateConstantNode(bool indices_1D,
|
||||
std::function<void(const std::vector<T>& values, TensorProto& tp)> inserter,
|
||||
std::vector<T>& expected_data) {
|
||||
NodeProto constant_node;
|
||||
constant_node.set_op_type("Constant");
|
||||
constant_node.add_output("dense_tensor_output");
|
||||
|
||||
std::vector<T> values = CreateValues<T>();
|
||||
std::vector<int64_t> indices;
|
||||
std::vector<int64_t> shape{2, 3, 2};
|
||||
|
||||
AttributeProto& attrib = *constant_node.mutable_attribute()->Add();
|
||||
attrib.set_name("sparse_value");
|
||||
attrib.set_type(AttributeProto_AttributeType_SPARSE_TENSOR);
|
||||
|
||||
SparseTensorProto& stp = *attrib.mutable_sparse_tensor();
|
||||
TensorProto& indices_tp = *stp.mutable_indices();
|
||||
|
||||
stp.mutable_dims()->Add(shape.cbegin(), shape.cend());
|
||||
for (auto dim : stp.dims())
|
||||
std::cout << dim;
|
||||
|
||||
if (indices_1D) {
|
||||
indices = {2, 5, 6, 10};
|
||||
indices_tp.add_dims(indices.size());
|
||||
} else {
|
||||
// indices are shape {NNZ, rank} so convert flattened values of 2, 5, 6 and 10 to rank 3 values
|
||||
indices_tp.add_dims(values.size());
|
||||
indices_tp.add_dims(shape.size());
|
||||
indices = {
|
||||
0, 1, 0,
|
||||
0, 2, 1,
|
||||
1, 0, 0,
|
||||
1, 2, 0};
|
||||
}
|
||||
|
||||
indices_tp.mutable_int64_data()->Add(indices.cbegin(), indices.cend());
|
||||
|
||||
expected_data.resize(2 * 3 * 2);
|
||||
expected_data[2] = values[0];
|
||||
expected_data[5] = values[1];
|
||||
expected_data[6] = values[2];
|
||||
expected_data[10] = values[3];
|
||||
|
||||
stp.mutable_values()->add_dims(values.size());
|
||||
inserter(values, *stp.mutable_values());
|
||||
|
||||
return constant_node;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void TestConversion(bool use_1D_indices,
|
||||
std::function<void(const std::vector<T>& values, TensorProto& tp)> inserter,
|
||||
std::function<void(gsl::span<const T> expected, const TensorProto& actual)> checker) {
|
||||
std::vector<T> expected;
|
||||
auto node = CreateConstantNode<T>(use_1D_indices, inserter, expected);
|
||||
|
||||
TensorProto dense;
|
||||
utils::ConstantNodeProtoToTensorProto(node, dense);
|
||||
|
||||
gsl::span<const T> expected_span = gsl::make_span<const T>(expected.data(), expected.size());
|
||||
checker(expected_span, dense);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void TestConversion(
|
||||
std::function<void(const std::vector<T>& values, TensorProto& tp)> inserter,
|
||||
std::function<void(gsl::span<const T> expected, const TensorProto& actual)> checker) {
|
||||
TestConversion(true, inserter, checker);
|
||||
TestConversion(false, inserter, checker);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void RawDataWriter(const std::vector<T>& values, TensorProto& tp, TensorProto_DataType datatype) {
|
||||
tp.set_data_type(datatype);
|
||||
tp.set_raw_data(values.data(), values.size() * sizeof(T));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void RawDataChecker(gsl::span<const T> expected, const TensorProto& actual) {
|
||||
int64_t actual_size = 1;
|
||||
for (const auto dim : actual.dims()) {
|
||||
actual_size *= dim;
|
||||
}
|
||||
|
||||
const T* raw_data = reinterpret_cast<const T*>(actual.raw_data().data());
|
||||
auto actual_span = gsl::make_span<const T>(raw_data, actual_size);
|
||||
|
||||
EXPECT_THAT(actual_span, testing::ContainerEq(expected));
|
||||
}
|
||||
|
||||
// Convert Constant nodes holding a sparse tensor to a dense TensorProto for each supported value type,
// with the sparse values provided via the type specific field, via raw data, and as strings.
TEST(SparseTensorConversionTests, TestConstantNodeConversion) {
  // values provided in the type specific data field
  TestConversion<float>(
      [](const std::vector<float>& values, TensorProto& tp) {
        tp.set_data_type(TensorProto_DataType_FLOAT);
        tp.mutable_float_data()->Add(values.cbegin(), values.cend());
      },
      RawDataChecker<float>);

  TestConversion<int32_t>(
      [](const std::vector<int32_t>& values, TensorProto& tp) {
        tp.set_data_type(TensorProto_DataType_INT32);
        tp.mutable_int32_data()->Add(values.cbegin(), values.cend());
      },
      RawDataChecker<int32_t>);

  TestConversion<int64_t>(
      [](const std::vector<int64_t>& values, TensorProto& tp) {
        tp.set_data_type(TensorProto_DataType_INT64);
        tp.mutable_int64_data()->Add(values.cbegin(), values.cend());
      },
      RawDataChecker<int64_t>);

  TestConversion<double>(
      [](const std::vector<double>& values, TensorProto& tp) {
        tp.set_data_type(TensorProto_DataType_DOUBLE);
        tp.mutable_double_data()->Add(values.cbegin(), values.cend());
      },
      RawDataChecker<double>);

  TestConversion<uint32_t>(
      [](const std::vector<uint32_t>& values, TensorProto& tp) {
        tp.set_data_type(TensorProto_DataType_UINT32);
        tp.mutable_uint64_data()->Add(values.cbegin(), values.cend());  // stored in uint64_data despite being uint32_t
      },
      RawDataChecker<uint32_t>);

  TestConversion<uint64_t>(
      [](const std::vector<uint64_t>& values, TensorProto& tp) {
        tp.set_data_type(TensorProto_DataType_UINT64);
        tp.mutable_uint64_data()->Add(values.cbegin(), values.cend());
      },
      RawDataChecker<uint64_t>);

  // test a couple of types with values in raw data field
  TestConversion<float>(
      [](const std::vector<float>& values, TensorProto& tp) {
        RawDataWriter(values, tp, TensorProto_DataType_FLOAT);
      },
      RawDataChecker<float>);

  TestConversion<int64_t>(
      [](const std::vector<int64_t>& values, TensorProto& tp) {
        RawDataWriter(values, tp, TensorProto_DataType_INT64);
      },
      RawDataChecker<int64_t>);

  // strings can't use raw data, and string_data is a RepeatedPtrField (vs. RepeatedField for simple types)
  // so has to be handled differently
  TestConversion<std::string>(
      [](const std::vector<std::string>& values, TensorProto& tp) {
        tp.set_data_type(TensorProto_DataType_STRING);
        for (const std::string& value : values) {
          tp.mutable_string_data()->Add(std::string(value));
        }
      },
      [](gsl::span<const std::string> expected, const TensorProto& actual) {
        const auto& actual_strings = actual.string_data();
        // the loop below indexes actual_strings, so it must hold as many entries as we expect
        ASSERT_EQ(static_cast<int64_t>(actual_strings.size()), static_cast<int64_t>(expected.size()));

        for (int64_t i = 0, end = expected.size(); i < end; ++i) {
          EXPECT_EQ(actual_strings[static_cast<int32_t>(i)], expected[i]);
        }
      });
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -3,21 +3,16 @@
|
|||
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "test/util/include/asserts.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
using namespace ::onnxruntime::utils;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
namespace {
|
||||
template <typename T>
|
||||
Status UnpackTensorWrapper(const ONNX_NAMESPACE::TensorProto& tensor, /*out*/ T* p_data, int64_t expected_size) {
|
||||
if (tensor.has_raw_data())
|
||||
return UnpackTensor(tensor, tensor.raw_data().data(), tensor.raw_data().size(), p_data, expected_size);
|
||||
return UnpackTensor(tensor, nullptr, 0, p_data, expected_size);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// T must be float or double, and it must match the 'type' argument
|
||||
template <typename T>
|
||||
|
|
@ -32,7 +27,7 @@ void test_unpack_float_tensor(TensorProto_DataType type) {
|
|||
}
|
||||
float_tensor_proto.set_raw_data(rawdata, len);
|
||||
T float_data2[4];
|
||||
auto status = UnpackTensorWrapper(float_tensor_proto, float_data2, 4);
|
||||
auto status = UnpackTensor(float_tensor_proto, float_data2, 4);
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
EXPECT_EQ(1.1f, float_data2[0]);
|
||||
EXPECT_EQ(2.2f, float_data2[1]);
|
||||
|
|
@ -46,12 +41,12 @@ TEST(TensorParseTest, TensorUtilsTest) {
|
|||
bool_tensor_proto.add_int32_data(1);
|
||||
|
||||
bool bool_data[1];
|
||||
auto status = UnpackTensorWrapper(bool_tensor_proto, bool_data, 1);
|
||||
auto status = UnpackTensor(bool_tensor_proto, bool_data, 1);
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
EXPECT_TRUE(bool_data[0]);
|
||||
|
||||
float float_data[1];
|
||||
status = UnpackTensorWrapper(bool_tensor_proto, float_data, 1);
|
||||
status = UnpackTensor(bool_tensor_proto, float_data, 1);
|
||||
EXPECT_FALSE(status.IsOK());
|
||||
|
||||
test_unpack_float_tensor<float>(TensorProto_DataType_FLOAT);
|
||||
|
|
@ -63,13 +58,121 @@ TEST(TensorParseTest, TensorUtilsTest) {
|
|||
string_tensor_proto.add_string_data("b");
|
||||
|
||||
std::string string_data[2];
|
||||
status = UnpackTensorWrapper(string_tensor_proto, string_data, 2);
|
||||
status = UnpackTensor(string_tensor_proto, string_data, 2);
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
EXPECT_EQ("a", string_data[0]);
|
||||
EXPECT_EQ("b", string_data[1]);
|
||||
|
||||
status = UnpackTensorWrapper(bool_tensor_proto, string_data, 2);
|
||||
status = UnpackTensor(bool_tensor_proto, string_data, 2);
|
||||
EXPECT_FALSE(status.IsOK());
|
||||
}
|
||||
|
||||
// Input values for the Constant node tests: the numbers 1 to 4 as type T.
template <typename T>
static std::vector<T> CreateValues() {
  std::vector<T> result;
  for (int number = 1; number < 5; ++number) {
    result.emplace_back(static_cast<T>(number));
  }

  return result;
}

// Input values for the string attribute tests.
template <>
std::vector<std::string> CreateValues<std::string>() {
  std::vector<std::string> result;
  result.emplace_back("one");
  result.emplace_back("two");
  result.emplace_back("three");
  result.emplace_back("four");
  return result;
}
|
||||
|
||||
template <typename T>
|
||||
static NodeProto CreateConstantNode(const std::string& attrib_name, AttributeProto_AttributeType type,
|
||||
std::function<void(AttributeProto&)> add_data) {
|
||||
NodeProto constant_node;
|
||||
constant_node.set_op_type("Constant");
|
||||
constant_node.add_output("Constant_output");
|
||||
|
||||
AttributeProto& attrib = *constant_node.mutable_attribute()->Add();
|
||||
|
||||
attrib.set_name(attrib_name);
|
||||
attrib.set_type(type);
|
||||
add_data(attrib);
|
||||
|
||||
return constant_node;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void TestConstantNodeConversion(const std::string& attrib_name,
|
||||
AttributeProto_AttributeType type,
|
||||
std::function<void(AttributeProto&, const std::vector<T>& data)> add_data,
|
||||
std::function<std::vector<T>(const TensorProto&)> get_data,
|
||||
int64_t num_elements) {
|
||||
auto input = CreateValues<T>();
|
||||
if (num_elements == -1) {
|
||||
num_elements = static_cast<int64_t>(input.size());
|
||||
} else {
|
||||
input.resize(num_elements);
|
||||
}
|
||||
|
||||
auto c = CreateConstantNode<T>(
|
||||
attrib_name, type,
|
||||
[&input, &add_data](AttributeProto& attrib) { add_data(attrib, input); });
|
||||
|
||||
TensorProto tp;
|
||||
EXPECT_STATUS_OK(utils::ConstantNodeProtoToTensorProto(c, tp));
|
||||
|
||||
EXPECT_THAT(get_data(tp), ::testing::ContainerEq(input));
|
||||
}
|
||||
|
||||
// Check the conversion of each non-tensor attribute type a Constant node can carry:
// a single float/int/string and a list of floats/ints/strings.
TEST(TensorProtoUtilsTest, ConstantTensorProto) {
  // readers that copy the converted values out of the type specific field of the TensorProto
  const auto read_floats = [](const TensorProto& tp) {
    return std::vector<float>(tp.float_data().cbegin(), tp.float_data().cend());
  };
  const auto read_int64s = [](const TensorProto& tp) {
    return std::vector<int64_t>(tp.int64_data().cbegin(), tp.int64_data().cend());
  };
  const auto read_strings = [](const TensorProto& tp) {
    return std::vector<std::string>(tp.string_data().cbegin(), tp.string_data().cend());
  };

  // num_elements argument: 1 uses only the first input value, -1 uses all of them
  constexpr int64_t kSingleValue = 1;
  constexpr int64_t kAllValues = -1;

  TestConstantNodeConversion<float>(
      "value_float", AttributeProto_AttributeType_FLOAT,
      [](AttributeProto& attrib, const std::vector<float>& data) { attrib.set_f(data[0]); },
      read_floats,
      kSingleValue);

  TestConstantNodeConversion<float>(
      "value_floats", AttributeProto_AttributeType_FLOATS,
      [](AttributeProto& attrib, const std::vector<float>& data) {
        *attrib.mutable_floats() = {data.cbegin(), data.cend()};
      },
      read_floats,
      kAllValues);

  TestConstantNodeConversion<int64_t>(
      "value_int", AttributeProto_AttributeType_INT,
      [](AttributeProto& attrib, const std::vector<int64_t>& data) { attrib.set_i(data[0]); },
      read_int64s,
      kSingleValue);

  TestConstantNodeConversion<int64_t>(
      "value_ints", AttributeProto_AttributeType_INTS,
      [](AttributeProto& attrib, const std::vector<int64_t>& data) {
        *attrib.mutable_ints() = {data.cbegin(), data.cend()};
      },
      read_int64s,
      kAllValues);

  TestConstantNodeConversion<std::string>(
      "value_string", AttributeProto_AttributeType_STRING,
      [](AttributeProto& attrib, const std::vector<std::string>& data) { attrib.set_s(data[0]); },
      read_strings,
      kSingleValue);

  TestConstantNodeConversion<std::string>(
      "value_strings", AttributeProto_AttributeType_STRINGS,
      [](AttributeProto& attrib, const std::vector<std::string>& data) {
        *attrib.mutable_strings() = {data.cbegin(), data.cend()};
      },
      read_strings,
      kAllValues);

  // sparse_tensor is covered by SparseTensorConversionTests.TestConstantNodeConversion
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue