Constant-12 support (#3304)

1. Support the new fields for Constant in opset 12
2. Support SparseTensor in the Constant node by converting to dense tensor when lifting the Constant to an initializer. Will make a model with a sparse tensor in a Constant work but isn't an overly efficient approach.
This commit is contained in:
Scott McKay 2020-03-31 16:13:52 +10:00 committed by GitHub
parent 2332a93db0
commit ace741680d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 570 additions and 54 deletions

View file

@ -59,28 +59,28 @@ namespace onnxruntime {
namespace utils {
// This macro doesn't work for Float16/bool/string tensors
#define DEFINE_UNPACK_TENSOR(T, Type, field_name, field_size) \
template <> \
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, \
/*out*/ T* p_data, size_t expected_size) { \
if (nullptr == p_data) { \
const size_t size = raw_data != nullptr ? raw_data_len : tensor.field_size(); \
if (size == 0) return Status::OK(); \
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \
} \
if (nullptr == p_data || Type != tensor.data_type()) { \
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \
} \
if (raw_data != nullptr) { \
return UnpackTensorWithRawData(raw_data, raw_data_len, expected_size, p_data); \
} \
if (static_cast<size_t>(tensor.field_size()) != expected_size) \
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "corrupted protobuf data: tensor shape size(", expected_size, \
") does not match the data size(", tensor.field_size(), ") in proto"); \
auto& data = tensor.field_name(); \
for (auto data_iter = data.cbegin(); data_iter != data.cend(); ++data_iter) \
*p_data++ = *reinterpret_cast<const T*>(data_iter); \
return Status::OK(); \
#define DEFINE_UNPACK_TENSOR(T, Type, field_name, field_size) \
template <> \
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, \
/*out*/ T* p_data, size_t expected_size) { \
if (nullptr == p_data) { \
const size_t size = raw_data != nullptr ? raw_data_len : tensor.field_size(); \
if (size == 0) return Status::OK(); \
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \
} \
if (nullptr == p_data || Type != tensor.data_type()) { \
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \
} \
if (raw_data != nullptr) { \
return UnpackTensorWithRawData(raw_data, raw_data_len, expected_size, p_data); \
} \
if (static_cast<size_t>(tensor.field_size()) != expected_size) \
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "corrupted protobuf data: tensor shape size(", \
expected_size, ") does not match the data size(", tensor.field_size(), ") in proto"); \
auto& data = tensor.field_name(); \
for (auto data_iter = data.cbegin(); data_iter != data.cend(); ++data_iter) \
*p_data++ = *reinterpret_cast<const T*>(data_iter); \
return Status::OK(); \
}
// TODO: complex64 complex128
@ -310,10 +310,10 @@ ORT_API(void, OrtUninitializeBuffer, _In_opt_ void* input, size_t input_len, enu
}
}
#define CASE_PROTO(X, Y) \
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \
ORT_RETURN_IF_ERROR( \
::onnxruntime::utils::UnpackTensor<Y>(tensor_proto, raw_data, raw_data_len, (Y*)preallocated, static_cast<size_t>(tensor_size))); \
#define CASE_PROTO(X, Y) \
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \
ORT_RETURN_IF_ERROR( \
UnpackTensor<Y>(tensor_proto, raw_data, raw_data_len, (Y*)preallocated, static_cast<size_t>(tensor_size))); \
break;
class AutoDelete {
@ -421,7 +421,8 @@ Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* tensor_proto_path,
int64_t tensor_size = 1;
{
for (auto i : tensor_proto.dims()) {
if (i < 0) return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "tensor can't contain negative dims");
if (i < 0)
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "tensor can't contain negative dims");
tensor_size *= i;
}
}
@ -464,8 +465,8 @@ Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* tensor_proto_path,
deleter.f = UnInitTensor;
deleter.param = new UnInitializeParam{preallocated, preallocated_size, ele_type};
}
ORT_RETURN_IF_ERROR(::onnxruntime::utils::UnpackTensor<std::string>(tensor_proto, raw_data, raw_data_len,
(std::string*)preallocated, static_cast<size_t>(tensor_size)));
ORT_RETURN_IF_ERROR(UnpackTensor<std::string>(tensor_proto, raw_data, raw_data_len,
(std::string*)preallocated, static_cast<size_t>(tensor_size)));
break;
default: {
std::ostringstream ostr;
@ -536,12 +537,238 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std:
//tensor_proto.set_data_type(utils::GetTensorProtoType(tensor));
tensor_proto.set_data_type(tensor_proto_type.tensor_type().elem_type());
tensor_proto.set_raw_data(tensor.DataRaw(), tensor.SizeInBytes());
return tensor_proto;
}
common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node,
ONNX_NAMESPACE::TensorProto& tensor) {
const AttributeProto& constant_attribute = node.attribute(0);
switch (constant_attribute.type()) {
case AttributeProto_AttributeType_TENSOR:
tensor = constant_attribute.t();
break;
case AttributeProto_AttributeType_FLOAT:
tensor.set_data_type(TensorProto_DataType_FLOAT);
tensor.add_float_data(constant_attribute.f());
break;
case AttributeProto_AttributeType_FLOATS:
tensor.set_data_type(TensorProto_DataType_FLOAT);
*tensor.mutable_float_data() = constant_attribute.floats();
break;
case AttributeProto_AttributeType_INT:
tensor.set_data_type(TensorProto_DataType_INT64);
tensor.add_int64_data(constant_attribute.i());
break;
case AttributeProto_AttributeType_INTS:
tensor.set_data_type(TensorProto_DataType_INT64);
*tensor.mutable_int64_data() = constant_attribute.ints();
break;
case AttributeProto_AttributeType_STRING:
tensor.set_data_type(TensorProto_DataType_STRING);
tensor.add_string_data(constant_attribute.s());
break;
case AttributeProto_AttributeType_STRINGS: {
tensor.set_data_type(TensorProto_DataType_STRING);
*tensor.mutable_string_data() = constant_attribute.strings();
break;
}
case AttributeProto_AttributeType_SPARSE_TENSOR: {
auto& s = constant_attribute.sparse_tensor();
ORT_RETURN_IF_ERROR(SparseTensorProtoToDenseTensorProto(s, tensor));
break;
}
default:
ORT_THROW("Unsupported attribute value type of ", constant_attribute.type(),
" in 'Constant' node '", node.name(), "'");
}
// set name last in case attribute type was tensor (would copy over name)
*(tensor.mutable_name()) = node.output(0);
return Status::OK();
}
template <typename T>
static Status CopySparseData(size_t n_sparse_elements,
const ONNX_NAMESPACE::TensorProto& indices,
gsl::span<const int64_t> dims,
std::function<void(size_t from_idx, size_t to_idx)> copier) {
Status status = Status::OK();
TensorShape indices_shape(indices.dims().data(), indices.dims().size());
auto indices_data = gsl::make_span<const int64_t>(indices.int64_data().data(), static_cast<size_t>(indices_shape.Size()));
if (indices_shape.NumDimensions() == 1) {
// flattened indexes
for (size_t i = 0; i < n_sparse_elements; ++i) {
copier(i, static_cast<size_t>(indices_data[i]));
}
} else if (indices_shape.NumDimensions() == 2) {
// entries in format {NNZ, rank}
size_t rank = static_cast<size_t>(indices_shape[1]);
ORT_ENFORCE(rank == dims.size() && rank > 0);
const int64_t* cur_index = indices_data.data();
std::vector<size_t> multipliers;
multipliers.resize(rank);
// calculate sum of inner dimension elements for each dimension.
// e.g. if shape {2,3,4}, the result should be {3*4, 4, 1}
multipliers[rank - 1] = 1;
for (int32_t r = static_cast<int32_t>(rank) - 2; r >= 0; --r) {
multipliers[r] = static_cast<size_t>(dims[r + 1]) * multipliers[r + 1];
}
// calculate the offset for the entry
// e.g. if shape was {2,3,4} and entry was (1, 0, 2) the offset is 14
// as there are 2 rows, each with 12 entries per row
for (size_t i = 0; i < n_sparse_elements; ++i) {
size_t idx = 0;
for (size_t j = 0; j < rank; ++j) {
idx += static_cast<size_t>(cur_index[j]) * multipliers[j];
}
copier(i, idx);
cur_index += rank;
}
ORT_ENFORCE(cur_index == &*indices_data.cend());
} else {
status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Invalid SparseTensor indices. Should be rank 0 or 1. Got:",
indices_shape);
}
return status;
}
common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseTensorProto& sparse,
ONNX_NAMESPACE::TensorProto& dense) {
Status status = Status::OK();
const auto& sparse_values = sparse.values();
auto type = sparse_values.data_type();
dense.set_data_type(type);
SafeInt<size_t> n_sparse_elements = 1;
for (auto dim : sparse_values.dims()) {
n_sparse_elements *= dim;
}
SafeInt<size_t> n_dense_elements = 1;
for (auto dim : sparse.dims()) {
n_dense_elements *= dim;
dense.add_dims(dim);
}
const auto& indices = sparse.indices();
auto dims = gsl::make_span<const int64_t>(dense.dims().data(), dense.dims().size());
// need to read in sparse data first as it could be in a type specific field, in raw data, or in external data
size_t sparse_bytes;
ORT_RETURN_IF_ERROR(GetSizeInBytesFromTensorProto<0>(sparse_values, &sparse_bytes));
if (type != TensorProto_DataType_STRING) {
std::vector<unsigned char> sparse_data_storage(sparse_bytes, 0);
void* sparse_data = sparse_data_storage.data();
size_t element_size = 0;
// setup buffer for output
switch (type) {
case TensorProto_DataType_FLOAT: {
element_size = sizeof(float);
UnpackTensor<float>(sparse_values, static_cast<float*>(sparse_data), n_sparse_elements);
break;
}
case TensorProto_DataType_INT64: {
element_size = sizeof(int64_t);
UnpackTensor<int64_t>(sparse_values, static_cast<int64_t*>(sparse_data), n_sparse_elements);
break;
}
case TensorProto_DataType_INT32: {
element_size = sizeof(int32_t);
UnpackTensor<int32_t>(sparse_values, static_cast<int32_t*>(sparse_data), n_sparse_elements);
break;
}
case TensorProto_DataType_DOUBLE: {
element_size = sizeof(double);
UnpackTensor<double>(sparse_values, static_cast<double*>(sparse_data), n_sparse_elements);
break;
}
case TensorProto_DataType_UINT32: {
element_size = sizeof(uint32_t);
UnpackTensor<uint32_t>(sparse_values, static_cast<uint32_t*>(sparse_data), n_sparse_elements);
break;
}
case TensorProto_DataType_UINT64: {
element_size = sizeof(uint64_t);
UnpackTensor<uint64_t>(sparse_values, static_cast<uint64_t*>(sparse_data), n_sparse_elements);
break;
}
default:
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported sparse tensor data type of ", type);
}
// by putting the data into a std::string we can avoid a copy as set_raw_data can do a std::move
// into the TensorProto. however to actually write to the buffer we have created in the std::string we need
// this somewhat dirty hack to get a mutable pointer. we could alternatively use &dense_data_storage.front()
// but using const_cast makes it more obvious we're doing something ugly.
std::string dense_data_storage(n_dense_elements * element_size, 0);
void* dense_data = const_cast<char*>(dense_data_storage.data());
switch (element_size) {
case 4: {
auto dense_data_span = gsl::make_span<uint32_t>(static_cast<uint32_t*>(dense_data), n_dense_elements);
status = CopySparseData<uint32_t>(
n_sparse_elements,
indices, dims,
[sparse_data, dense_data_span](size_t from_idx, size_t to_idx) {
dense_data_span[to_idx] = static_cast<const uint32_t*>(sparse_data)[from_idx];
});
break;
}
case 8: {
auto dense_data_span = gsl::make_span<uint64_t>(static_cast<uint64_t*>(dense_data), n_dense_elements);
status = CopySparseData<uint64_t>(
n_sparse_elements,
indices, dims,
[sparse_data, dense_data_span](size_t from_idx, size_t to_idx) {
dense_data_span[to_idx] = static_cast<const uint64_t*>(sparse_data)[from_idx];
});
break;
}
}
dense.set_raw_data(std::move(dense_data_storage));
} else {
// strings need to be handled differently as they can't use raw data (as per ONNX rules)
std::vector<std::string> sparse_data(n_sparse_elements);
UnpackTensor<std::string>(sparse_values, sparse_data.data(), n_sparse_elements);
// RepeatedPtrField<std::string> doesn't have a Resize method so manually add elements
auto dense_strings = dense.mutable_string_data();
dense_strings->Reserve(n_dense_elements);
for (int64_t j = 0; j < n_dense_elements; ++j) {
dense_strings->Add("");
}
status = CopySparseData<std::string>(
n_sparse_elements,
indices, dims,
[&sparse_values, &dense_strings](size_t from_idx, size_t to_idx) {
const std::string& input = sparse_values.string_data()[SafeInt<int32_t>(from_idx)];
*dense_strings->Mutable(SafeInt<int32_t>(to_idx)) = input;
});
}
return status;
}
template common::Status GetSizeInBytesFromTensorProto<256>(const ONNX_NAMESPACE::TensorProto& tensor_proto,
size_t* out);
template common::Status GetSizeInBytesFromTensorProto<0>(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out);

View file

@ -60,6 +60,14 @@ template <typename T>
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len,
/*out*/ T* p_data, size_t expected_size);
// Convert the NodeProto from a Constant node into a TensorProto that can be used as an initializer
common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node,
ONNX_NAMESPACE::TensorProto& tensor);
// Convert a SparseTensorProto to a dense TensorProto
common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseTensorProto& sparse,
ONNX_NAMESPACE::TensorProto& dense);
inline bool HasDimValue(const ONNX_NAMESPACE::TensorShapeProto_Dimension& dim) {
return dim.value_case() == ONNX_NAMESPACE::TensorShapeProto_Dimension::kDimValue;
}
@ -199,5 +207,13 @@ inline bool HasName(const ONNX_NAMESPACE::NodeProto& node_proto) {
return node_proto.has_name();
}
// UnpackTensor from either raw data or the type specific data field.
template <typename T>
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, /*out*/ T* p_data, size_t expected_size) {
return HasRawData(tensor)
? UnpackTensor(tensor, tensor.raw_data().data(), tensor.raw_data().size(), p_data, expected_size)
: UnpackTensor(tensor, nullptr, 0, p_data, expected_size);
}
} // namespace utils
} // namespace onnxruntime

View file

@ -703,7 +703,7 @@ Graph::Graph(const Model& owning_model,
const logging::Logger& logger,
const std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*>& model_functions)
: Graph(owning_model, graph_proto, domain_to_version, ir_version, schema_registry, nullptr, nullptr, logger,
model_functions) {}
model_functions) {}
Graph::Graph(const Model& owning_model,
GraphProto* graph_proto, const std::unordered_map<std::string, int>& domain_to_version, Version ir_version,
@ -731,16 +731,9 @@ Graph::Graph(const Model& owning_model,
continue;
}
// Copy constant nodes _value to name_to_initial_tensor_
const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
const AttributeProto& constant_attribute = node.attribute(0);
// TODO: Add support for parsing 'sparse_value' attribute from a 'Constant' node
// Discussion surrounding handling the SparseTensorProto must be had.
// An easy way is to implement a method that converts a SparseTensorproto into a TensorProto
// to use the same downstream flow, but that is going to impact peak memory usage and probably a smarter way is required.
ORT_ENFORCE(constant_attribute.has_t(), "Only 'value' attribute is supported within a 'Constant' node in ORT");
*tensor = constant_attribute.t();
*(tensor->mutable_name()) = node.output(0);
auto status = utils::ConstantNodeProtoToTensorProto(node, *tensor);
ORT_ENFORCE(status.IsOK(), status.ToString());
}
// Remove constant nodes as they're replaced with initializers above.

View file

@ -8,13 +8,16 @@
#include "core/graph/constants.h"
#include "core/framework/op_kernel.h"
#include "core/framework/sparse_tensor.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/model.h"
#include "core/session/inference_session.h"
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "test_utils.h"
#include "gtest/gtest.h"
#include "gmock/gmock.h"
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
@ -490,5 +493,179 @@ TEST_F(SparseTensorTests, Test2) {
RunTest();
}
template <typename T>
static std::vector<T> CreateValues() {
return {1, 2, 3, 4};
}
template <>
std::vector<std::string> CreateValues<std::string>() {
return {"one", "two", "three", "four"};
}
template <typename T>
static NodeProto CreateConstantNode(bool indices_1D,
std::function<void(const std::vector<T>& values, TensorProto& tp)> inserter,
std::vector<T>& expected_data) {
NodeProto constant_node;
constant_node.set_op_type("Constant");
constant_node.add_output("dense_tensor_output");
std::vector<T> values = CreateValues<T>();
std::vector<int64_t> indices;
std::vector<int64_t> shape{2, 3, 2};
AttributeProto& attrib = *constant_node.mutable_attribute()->Add();
attrib.set_name("sparse_value");
attrib.set_type(AttributeProto_AttributeType_SPARSE_TENSOR);
SparseTensorProto& stp = *attrib.mutable_sparse_tensor();
TensorProto& indices_tp = *stp.mutable_indices();
stp.mutable_dims()->Add(shape.cbegin(), shape.cend());
for (auto dim : stp.dims())
std::cout << dim;
if (indices_1D) {
indices = {2, 5, 6, 10};
indices_tp.add_dims(indices.size());
} else {
// indices are shape {NNZ, rank} so convert flattened values of 2, 5, 6 and 10 to rank 3 values
indices_tp.add_dims(values.size());
indices_tp.add_dims(shape.size());
indices = {
0, 1, 0,
0, 2, 1,
1, 0, 0,
1, 2, 0};
}
indices_tp.mutable_int64_data()->Add(indices.cbegin(), indices.cend());
expected_data.resize(2 * 3 * 2);
expected_data[2] = values[0];
expected_data[5] = values[1];
expected_data[6] = values[2];
expected_data[10] = values[3];
stp.mutable_values()->add_dims(values.size());
inserter(values, *stp.mutable_values());
return constant_node;
}
template <typename T>
static void TestConversion(bool use_1D_indices,
std::function<void(const std::vector<T>& values, TensorProto& tp)> inserter,
std::function<void(gsl::span<const T> expected, const TensorProto& actual)> checker) {
std::vector<T> expected;
auto node = CreateConstantNode<T>(use_1D_indices, inserter, expected);
TensorProto dense;
utils::ConstantNodeProtoToTensorProto(node, dense);
gsl::span<const T> expected_span = gsl::make_span<const T>(expected.data(), expected.size());
checker(expected_span, dense);
}
template <typename T>
static void TestConversion(
std::function<void(const std::vector<T>& values, TensorProto& tp)> inserter,
std::function<void(gsl::span<const T> expected, const TensorProto& actual)> checker) {
TestConversion(true, inserter, checker);
TestConversion(false, inserter, checker);
}
template <typename T>
static void RawDataWriter(const std::vector<T>& values, TensorProto& tp, TensorProto_DataType datatype) {
tp.set_data_type(datatype);
tp.set_raw_data(values.data(), values.size() * sizeof(T));
}
template <typename T>
static void RawDataChecker(gsl::span<const T> expected, const TensorProto& actual) {
int64_t actual_size = 1;
for (const auto dim : actual.dims()) {
actual_size *= dim;
}
const T* raw_data = reinterpret_cast<const T*>(actual.raw_data().data());
auto actual_span = gsl::make_span<const T>(raw_data, actual_size);
EXPECT_THAT(actual_span, testing::ContainerEq(expected));
}
TEST(SparseTensorConversionTests, TestConstantNodeConversion) {
TestConversion<float>(
[](const std::vector<float>& values, TensorProto& tp) {
tp.set_data_type(TensorProto_DataType_FLOAT);
tp.mutable_float_data()->Add(values.cbegin(), values.cend());
},
RawDataChecker<float>);
TestConversion<int32_t>(
[](const std::vector<int32_t>& values, TensorProto& tp) {
tp.set_data_type(TensorProto_DataType_INT32);
tp.mutable_int32_data()->Add(values.cbegin(), values.cend());
},
RawDataChecker<int32_t>);
TestConversion<int64_t>(
[](const std::vector<int64_t>& values, TensorProto& tp) {
tp.set_data_type(TensorProto_DataType_INT64);
tp.mutable_int64_data()->Add(values.cbegin(), values.cend());
},
RawDataChecker<int64_t>);
TestConversion<double>(
[](const std::vector<double>& values, TensorProto& tp) {
tp.set_data_type(TensorProto_DataType_DOUBLE);
tp.mutable_double_data()->Add(values.cbegin(), values.cend());
},
RawDataChecker<double>);
TestConversion<uint32_t>(
[](const std::vector<uint32_t>& values, TensorProto& tp) {
tp.set_data_type(TensorProto_DataType_UINT32);
tp.mutable_uint64_data()->Add(values.cbegin(), values.cend()); // stored in uint64_data despite being uint32_t
},
RawDataChecker<uint32_t>);
TestConversion<uint64_t>(
[](const std::vector<uint64_t>& values, TensorProto& tp) {
tp.set_data_type(TensorProto_DataType_UINT64);
tp.mutable_uint64_data()->Add(values.cbegin(), values.cend());
},
RawDataChecker<uint64_t>);
// test a couple of types with values in raw data field
TestConversion<float>(
[](const std::vector<float>& values, TensorProto& tp) {
RawDataWriter(values, tp, TensorProto_DataType_FLOAT);
},
RawDataChecker<float>);
TestConversion<int64_t>(
[](const std::vector<int64_t>& values, TensorProto& tp) {
RawDataWriter(values, tp, TensorProto_DataType_INT64);
},
RawDataChecker<int64_t>);
// strings can't use raw data, and string_data is a RepeatedPtrField (vs. RepeatedField for simple types)
// so has to be handled differently
TestConversion<std::string>(
[](const std::vector<std::string>& values, TensorProto& tp) {
tp.set_data_type(TensorProto_DataType_STRING);
for (auto cur = values.cbegin(), end = values.cend(); cur < end; ++cur) {
tp.mutable_string_data()->Add(std::string(*cur));
}
},
[](gsl::span<const std::string> expected, const TensorProto& actual) {
const auto& actual_strings = actual.string_data();
for (int64_t i = 0, end = expected.size(); i < end; ++i) {
EXPECT_EQ(actual_strings[static_cast<int32_t>(i)], expected[i]);
}
});
}
} // namespace test
} // namespace onnxruntime

View file

@ -3,21 +3,16 @@
#include "core/framework/tensorprotoutils.h"
#include "core/graph/onnx_protobuf.h"
#include "test/util/include/asserts.h"
#include "gtest/gtest.h"
#include "gmock/gmock.h"
using namespace ::onnxruntime::utils;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace test {
namespace {
template <typename T>
Status UnpackTensorWrapper(const ONNX_NAMESPACE::TensorProto& tensor, /*out*/ T* p_data, int64_t expected_size) {
if (tensor.has_raw_data())
return UnpackTensor(tensor, tensor.raw_data().data(), tensor.raw_data().size(), p_data, expected_size);
return UnpackTensor(tensor, nullptr, 0, p_data, expected_size);
}
} // namespace
//T must be float for double, and it must match with the 'type' argument
template <typename T>
@ -32,7 +27,7 @@ void test_unpack_float_tensor(TensorProto_DataType type) {
}
float_tensor_proto.set_raw_data(rawdata, len);
T float_data2[4];
auto status = UnpackTensorWrapper(float_tensor_proto, float_data2, 4);
auto status = UnpackTensor(float_tensor_proto, float_data2, 4);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
EXPECT_EQ(1.1f, float_data2[0]);
EXPECT_EQ(2.2f, float_data2[1]);
@ -46,12 +41,12 @@ TEST(TensorParseTest, TensorUtilsTest) {
bool_tensor_proto.add_int32_data(1);
bool bool_data[1];
auto status = UnpackTensorWrapper(bool_tensor_proto, bool_data, 1);
auto status = UnpackTensor(bool_tensor_proto, bool_data, 1);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
EXPECT_TRUE(bool_data[0]);
float float_data[1];
status = UnpackTensorWrapper(bool_tensor_proto, float_data, 1);
status = UnpackTensor(bool_tensor_proto, float_data, 1);
EXPECT_FALSE(status.IsOK());
test_unpack_float_tensor<float>(TensorProto_DataType_FLOAT);
@ -63,13 +58,121 @@ TEST(TensorParseTest, TensorUtilsTest) {
string_tensor_proto.add_string_data("b");
std::string string_data[2];
status = UnpackTensorWrapper(string_tensor_proto, string_data, 2);
status = UnpackTensor(string_tensor_proto, string_data, 2);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
EXPECT_EQ("a", string_data[0]);
EXPECT_EQ("b", string_data[1]);
status = UnpackTensorWrapper(bool_tensor_proto, string_data, 2);
status = UnpackTensor(bool_tensor_proto, string_data, 2);
EXPECT_FALSE(status.IsOK());
}
template <typename T>
static std::vector<T> CreateValues() {
return {1, 2, 3, 4};
}
template <>
std::vector<std::string> CreateValues<std::string>() {
return {"one", "two", "three", "four"};
}
template <typename T>
static NodeProto CreateConstantNode(const std::string& attrib_name, AttributeProto_AttributeType type,
std::function<void(AttributeProto&)> add_data) {
NodeProto constant_node;
constant_node.set_op_type("Constant");
constant_node.add_output("Constant_output");
AttributeProto& attrib = *constant_node.mutable_attribute()->Add();
attrib.set_name(attrib_name);
attrib.set_type(type);
add_data(attrib);
return constant_node;
}
template <typename T>
static void TestConstantNodeConversion(const std::string& attrib_name,
AttributeProto_AttributeType type,
std::function<void(AttributeProto&, const std::vector<T>& data)> add_data,
std::function<std::vector<T>(const TensorProto&)> get_data,
int64_t num_elements) {
auto input = CreateValues<T>();
if (num_elements == -1) {
num_elements = static_cast<int64_t>(input.size());
} else {
input.resize(num_elements);
}
auto c = CreateConstantNode<T>(
attrib_name, type,
[&input, &add_data](AttributeProto& attrib) { add_data(attrib, input); });
TensorProto tp;
EXPECT_STATUS_OK(utils::ConstantNodeProtoToTensorProto(c, tp));
EXPECT_THAT(get_data(tp), ::testing::ContainerEq(input));
}
TEST(TensorProtoUtilsTest, ConstantTensorProto) {
TestConstantNodeConversion<float>(
"value_float", AttributeProto_AttributeType_FLOAT,
[](AttributeProto& attrib, const std::vector<float>& data) { attrib.set_f(data[0]); },
[](const TensorProto& tp) {
return std::vector<float>(tp.float_data().cbegin(), tp.float_data().cend());
},
1);
TestConstantNodeConversion<float>(
"value_floats", AttributeProto_AttributeType_FLOATS,
[](AttributeProto& attrib, const std::vector<float>& data) {
*attrib.mutable_floats() = {data.cbegin(), data.cend()};
},
[](const TensorProto& tp) {
return std::vector<float>(tp.float_data().cbegin(), tp.float_data().cend());
},
-1);
TestConstantNodeConversion<int64_t>(
"value_int", AttributeProto_AttributeType_INT,
[](AttributeProto& attrib, const std::vector<int64_t>& data) { attrib.set_i(data[0]); },
[](const TensorProto& tp) {
return std::vector<int64_t>(tp.int64_data().cbegin(), tp.int64_data().cend());
},
1);
TestConstantNodeConversion<int64_t>(
"value_ints", AttributeProto_AttributeType_INTS,
[](AttributeProto& attrib, const std::vector<int64_t>& data) {
*attrib.mutable_ints() = {data.cbegin(), data.cend()};
},
[](const TensorProto& tp) {
return std::vector<int64_t>(tp.int64_data().cbegin(), tp.int64_data().cend());
},
-1);
TestConstantNodeConversion<std::string>(
"value_string", AttributeProto_AttributeType_STRING,
[](AttributeProto& attrib, const std::vector<std::string>& data) { attrib.set_s(data[0]); },
[](const TensorProto& tp) {
return std::vector<std::string>(tp.string_data().cbegin(), tp.string_data().cend());
},
1);
TestConstantNodeConversion<std::string>(
"value_strings", AttributeProto_AttributeType_STRINGS,
[](AttributeProto& attrib, const std::vector<std::string>& data) {
// for (const auto& s : data)
*attrib.mutable_strings() = {data.cbegin(), data.cend()};
},
[](const TensorProto& tp) {
return std::vector<std::string>(tp.string_data().cbegin(), tp.string_data().cend());
},
-1);
// sparse_tensor is covered by SparseTensorConversionTests.TestConstantNodeConversion
}
} // namespace test
} // namespace onnxruntime