diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index ca295fc5b2..cbe635e60b 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -59,28 +59,28 @@ namespace onnxruntime { namespace utils { // This macro doesn't work for Float16/bool/string tensors -#define DEFINE_UNPACK_TENSOR(T, Type, field_name, field_size) \ - template <> \ - Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, \ - /*out*/ T* p_data, size_t expected_size) { \ - if (nullptr == p_data) { \ - const size_t size = raw_data != nullptr ? raw_data_len : tensor.field_size(); \ - if (size == 0) return Status::OK(); \ - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \ - } \ - if (nullptr == p_data || Type != tensor.data_type()) { \ - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \ - } \ - if (raw_data != nullptr) { \ - return UnpackTensorWithRawData(raw_data, raw_data_len, expected_size, p_data); \ - } \ - if (static_cast(tensor.field_size()) != expected_size) \ - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "corrupted protobuf data: tensor shape size(", expected_size, \ - ") does not match the data size(", tensor.field_size(), ") in proto"); \ - auto& data = tensor.field_name(); \ - for (auto data_iter = data.cbegin(); data_iter != data.cend(); ++data_iter) \ - *p_data++ = *reinterpret_cast(data_iter); \ - return Status::OK(); \ +#define DEFINE_UNPACK_TENSOR(T, Type, field_name, field_size) \ + template <> \ + Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, \ + /*out*/ T* p_data, size_t expected_size) { \ + if (nullptr == p_data) { \ + const size_t size = raw_data != nullptr ? raw_data_len : tensor.field_size(); \ + if (size == 0) return Status::OK(); \ + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \ + } \ + if (nullptr == p_data || Type != tensor.data_type()) { \ + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \ + } \ + if (raw_data != nullptr) { \ + return UnpackTensorWithRawData(raw_data, raw_data_len, expected_size, p_data); \ + } \ + if (static_cast(tensor.field_size()) != expected_size) \ + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "corrupted protobuf data: tensor shape size(", \ + expected_size, ") does not match the data size(", tensor.field_size(), ") in proto"); \ + auto& data = tensor.field_name(); \ + for (auto data_iter = data.cbegin(); data_iter != data.cend(); ++data_iter) \ + *p_data++ = *reinterpret_cast(data_iter); \ + return Status::OK(); \ } // TODO: complex64 complex128 @@ -310,10 +310,10 @@ ORT_API(void, OrtUninitializeBuffer, _In_opt_ void* input, size_t input_len, enu } } -#define CASE_PROTO(X, Y) \ - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ - ORT_RETURN_IF_ERROR( \ - ::onnxruntime::utils::UnpackTensor(tensor_proto, raw_data, raw_data_len, (Y*)preallocated, static_cast(tensor_size))); \ +#define CASE_PROTO(X, Y) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ + ORT_RETURN_IF_ERROR( \ + UnpackTensor(tensor_proto, raw_data, raw_data_len, (Y*)preallocated, static_cast(tensor_size))); \ break; class AutoDelete { @@ -421,7 +421,8 @@ Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* tensor_proto_path, int64_t tensor_size = 1; { for (auto i : tensor_proto.dims()) { - if (i < 0) return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "tensor can't contain negative dims"); + if (i < 0) + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "tensor can't contain negative dims"); tensor_size *= i; } } @@ -464,8 +465,8 @@ Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* tensor_proto_path, deleter.f = UnInitTensor; deleter.param = new UnInitializeParam{preallocated, preallocated_size, ele_type}; } - ORT_RETURN_IF_ERROR(::onnxruntime::utils::UnpackTensor(tensor_proto, raw_data, raw_data_len, - (std::string*)preallocated, static_cast(tensor_size))); + ORT_RETURN_IF_ERROR(UnpackTensor(tensor_proto, raw_data, raw_data_len, + (std::string*)preallocated, static_cast(tensor_size))); break; default: { std::ostringstream ostr; @@ -536,12 +537,238 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std: //tensor_proto.set_data_type(utils::GetTensorProtoType(tensor)); tensor_proto.set_data_type(tensor_proto_type.tensor_type().elem_type()); - tensor_proto.set_raw_data(tensor.DataRaw(), tensor.SizeInBytes()); return tensor_proto; } +common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node, + ONNX_NAMESPACE::TensorProto& tensor) { + const AttributeProto& constant_attribute = node.attribute(0); + + switch (constant_attribute.type()) { + case AttributeProto_AttributeType_TENSOR: + tensor = constant_attribute.t(); + break; + case AttributeProto_AttributeType_FLOAT: + tensor.set_data_type(TensorProto_DataType_FLOAT); + tensor.add_float_data(constant_attribute.f()); + break; + case AttributeProto_AttributeType_FLOATS: + tensor.set_data_type(TensorProto_DataType_FLOAT); + *tensor.mutable_float_data() = constant_attribute.floats(); + break; + case AttributeProto_AttributeType_INT: + tensor.set_data_type(TensorProto_DataType_INT64); + tensor.add_int64_data(constant_attribute.i()); + break; + case AttributeProto_AttributeType_INTS: + tensor.set_data_type(TensorProto_DataType_INT64); + *tensor.mutable_int64_data() = constant_attribute.ints(); + break; + case AttributeProto_AttributeType_STRING: + tensor.set_data_type(TensorProto_DataType_STRING); + tensor.add_string_data(constant_attribute.s()); + break; + case AttributeProto_AttributeType_STRINGS: { + tensor.set_data_type(TensorProto_DataType_STRING); + *tensor.mutable_string_data() = constant_attribute.strings(); + break; + } + case AttributeProto_AttributeType_SPARSE_TENSOR: { + auto& s = constant_attribute.sparse_tensor(); + ORT_RETURN_IF_ERROR(SparseTensorProtoToDenseTensorProto(s, tensor)); + break; + } + default: + ORT_THROW("Unsupported attribute value type of ", constant_attribute.type(), + " in 'Constant' node '", node.name(), "'"); + } + + // set name last in case attribute type was tensor (would copy over name) + *(tensor.mutable_name()) = node.output(0); + + return Status::OK(); +} + +template +static Status CopySparseData(size_t n_sparse_elements, + const ONNX_NAMESPACE::TensorProto& indices, + gsl::span dims, + std::function copier) { + Status status = Status::OK(); + TensorShape indices_shape(indices.dims().data(), indices.dims().size()); + + auto indices_data = gsl::make_span(indices.int64_data().data(), static_cast(indices_shape.Size())); + + if (indices_shape.NumDimensions() == 1) { + // flattened indexes + for (size_t i = 0; i < n_sparse_elements; ++i) { + copier(i, static_cast(indices_data[i])); + } + } else if (indices_shape.NumDimensions() == 2) { + // entries in format {NNZ, rank} + size_t rank = static_cast(indices_shape[1]); + ORT_ENFORCE(rank == dims.size() && rank > 0); + const int64_t* cur_index = indices_data.data(); + std::vector multipliers; + multipliers.resize(rank); + + // calculate sum of inner dimension elements for each dimension. + // e.g. if shape {2,3,4}, the result should be {3*4, 4, 1} + multipliers[rank - 1] = 1; + for (int32_t r = static_cast(rank) - 2; r >= 0; --r) { + multipliers[r] = static_cast(dims[r + 1]) * multipliers[r + 1]; + } + + // calculate the offset for the entry + // e.g. if shape was {2,3,4} and entry was (1, 0, 2) the offset is 14 + // as there are 2 rows, each with 12 entries per row + for (size_t i = 0; i < n_sparse_elements; ++i) { + size_t idx = 0; + for (size_t j = 0; j < rank; ++j) { + idx += static_cast(cur_index[j]) * multipliers[j]; + } + + copier(i, idx); + cur_index += rank; + } + + ORT_ENFORCE(cur_index == &*indices_data.cend()); + } else { + status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Invalid SparseTensor indices. Should be rank 0 or 1. Got:", + indices_shape); + } + + return status; +} + +common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseTensorProto& sparse, + ONNX_NAMESPACE::TensorProto& dense) { + Status status = Status::OK(); + + const auto& sparse_values = sparse.values(); + auto type = sparse_values.data_type(); + dense.set_data_type(type); + + SafeInt n_sparse_elements = 1; + for (auto dim : sparse_values.dims()) { + n_sparse_elements *= dim; + } + + SafeInt n_dense_elements = 1; + for (auto dim : sparse.dims()) { + n_dense_elements *= dim; + dense.add_dims(dim); + } + + const auto& indices = sparse.indices(); + auto dims = gsl::make_span(dense.dims().data(), dense.dims().size()); + + // need to read in sparse data first as it could be in a type specific field, in raw data, or in external data + size_t sparse_bytes; + ORT_RETURN_IF_ERROR(GetSizeInBytesFromTensorProto<0>(sparse_values, &sparse_bytes)); + + if (type != TensorProto_DataType_STRING) { + std::vector sparse_data_storage(sparse_bytes, 0); + void* sparse_data = sparse_data_storage.data(); + + size_t element_size = 0; + + // setup buffer for output + switch (type) { + case TensorProto_DataType_FLOAT: { + element_size = sizeof(float); + UnpackTensor(sparse_values, static_cast(sparse_data), n_sparse_elements); + break; + } + case TensorProto_DataType_INT64: { + element_size = sizeof(int64_t); + UnpackTensor(sparse_values, static_cast(sparse_data), n_sparse_elements); + break; + } + case TensorProto_DataType_INT32: { + element_size = sizeof(int32_t); + UnpackTensor(sparse_values, static_cast(sparse_data), n_sparse_elements); + break; + } + case TensorProto_DataType_DOUBLE: { + element_size = sizeof(double); + UnpackTensor(sparse_values, static_cast(sparse_data), n_sparse_elements); + break; + } + case TensorProto_DataType_UINT32: { + element_size = sizeof(uint32_t); + UnpackTensor(sparse_values, static_cast(sparse_data), n_sparse_elements); + break; + } + case TensorProto_DataType_UINT64: { + element_size = sizeof(uint64_t); + UnpackTensor(sparse_values, static_cast(sparse_data), n_sparse_elements); + break; + } + default: + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported sparse tensor data type of ", type); + } + + // by putting the data into a std::string we can avoid a copy as set_raw_data can do a std::move + // into the TensorProto. however to actually write to the buffer we have created in the std::string we need + // this somewhat dirty hack to get a mutable pointer. we could alternatively use &dense_data_storage.front() + // but using const_cast makes it more obvious we're doing something ugly. + std::string dense_data_storage(n_dense_elements * element_size, 0); + void* dense_data = const_cast(dense_data_storage.data()); + + switch (element_size) { + case 4: { + auto dense_data_span = gsl::make_span(static_cast(dense_data), n_dense_elements); + status = CopySparseData( + n_sparse_elements, + indices, dims, + [sparse_data, dense_data_span](size_t from_idx, size_t to_idx) { + dense_data_span[to_idx] = static_cast(sparse_data)[from_idx]; + }); + + break; + } + case 8: { + auto dense_data_span = gsl::make_span(static_cast(dense_data), n_dense_elements); + status = CopySparseData( + n_sparse_elements, + indices, dims, + [sparse_data, dense_data_span](size_t from_idx, size_t to_idx) { + dense_data_span[to_idx] = static_cast(sparse_data)[from_idx]; + }); + + break; + } + } + + dense.set_raw_data(std::move(dense_data_storage)); + + } else { + // strings need to be handled differently as they can't use raw data (as per ONNX rules) + std::vector sparse_data(n_sparse_elements); + UnpackTensor(sparse_values, sparse_data.data(), n_sparse_elements); + + // RepeatedPtrField doesn't have a Resize method so manually add elements + auto dense_strings = dense.mutable_string_data(); + dense_strings->Reserve(n_dense_elements); + for (int64_t j = 0; j < n_dense_elements; ++j) { + dense_strings->Add(""); + } + + status = CopySparseData( + n_sparse_elements, + indices, dims, + [&sparse_values, &dense_strings](size_t from_idx, size_t to_idx) { + const std::string& input = sparse_values.string_data()[SafeInt(from_idx)]; + *dense_strings->Mutable(SafeInt(to_idx)) = input; + }); + } + + return status; +} + template common::Status GetSizeInBytesFromTensorProto<256>(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out); template common::Status GetSizeInBytesFromTensorProto<0>(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out); diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 18c6f544a6..7e6fd72506 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -60,6 +60,14 @@ template Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ T* p_data, size_t expected_size); +// Convert the NodeProto from a Constant node into a TensorProto that can be used as an initializer +common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node, + ONNX_NAMESPACE::TensorProto& tensor); + +// Convert a SparseTensorProto to a dense TensorProto +common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseTensorProto& sparse, + ONNX_NAMESPACE::TensorProto& dense); + inline bool HasDimValue(const ONNX_NAMESPACE::TensorShapeProto_Dimension& dim) { return dim.value_case() == ONNX_NAMESPACE::TensorShapeProto_Dimension::kDimValue; } @@ -199,5 +207,13 @@ inline bool HasName(const ONNX_NAMESPACE::NodeProto& node_proto) { return node_proto.has_name(); } +// UnpackTensor from either raw data or the type specific data field. +template +Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, /*out*/ T* p_data, size_t expected_size) { + return HasRawData(tensor) + ? UnpackTensor(tensor, tensor.raw_data().data(), tensor.raw_data().size(), p_data, expected_size) + : UnpackTensor(tensor, nullptr, 0, p_data, expected_size); +} + } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 20cecbdfb7..1327aee8ff 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -703,7 +703,7 @@ Graph::Graph(const Model& owning_model, const logging::Logger& logger, const std::unordered_map& model_functions) : Graph(owning_model, graph_proto, domain_to_version, ir_version, schema_registry, nullptr, nullptr, logger, - model_functions) {} + model_functions) {} Graph::Graph(const Model& owning_model, GraphProto* graph_proto, const std::unordered_map& domain_to_version, Version ir_version, @@ -731,16 +731,9 @@ Graph::Graph(const Model& owning_model, continue; } - // Copy constant nodes _value to name_to_initial_tensor_ const gsl::not_null tensor{graph_proto_->add_initializer()}; - const AttributeProto& constant_attribute = node.attribute(0); - // TODO: Add support for parsing 'sparse_value' attribute from a 'Constant' node - // Discussion surrounding handling the SparseTensorProto must be had. - // An easy way is to implement a method that converts a SparseTensorproto into a TensorProto - // to use the same downstream flow, but that is going to impact peak memory usage and probably a smarter way is required. - ORT_ENFORCE(constant_attribute.has_t(), "Only 'value' attribute is supported within a 'Constant' node in ORT"); - *tensor = constant_attribute.t(); - *(tensor->mutable_name()) = node.output(0); + auto status = utils::ConstantNodeProtoToTensorProto(node, *tensor); + ORT_ENFORCE(status.IsOK(), status.ToString()); } // Remove constant nodes as they're replaced with initializers above. diff --git a/onnxruntime/test/framework/sparse_kernels_test.cc b/onnxruntime/test/framework/sparse_kernels_test.cc index e016149cef..df70d26171 100644 --- a/onnxruntime/test/framework/sparse_kernels_test.cc +++ b/onnxruntime/test/framework/sparse_kernels_test.cc @@ -8,13 +8,16 @@ #include "core/graph/constants.h" #include "core/framework/op_kernel.h" #include "core/framework/sparse_tensor.h" +#include "core/framework/tensorprotoutils.h" #include "core/graph/model.h" #include "core/session/inference_session.h" -#include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "test_utils.h" +#include "gtest/gtest.h" +#include "gmock/gmock.h" + using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; @@ -490,5 +493,179 @@ TEST_F(SparseTensorTests, Test2) { RunTest(); } +template +static std::vector CreateValues() { + return {1, 2, 3, 4}; +} + +template <> +std::vector CreateValues() { + return {"one", "two", "three", "four"}; +} + +template +static NodeProto CreateConstantNode(bool indices_1D, + std::function& values, TensorProto& tp)> inserter, + std::vector& expected_data) { + NodeProto constant_node; + constant_node.set_op_type("Constant"); + constant_node.add_output("dense_tensor_output"); + + std::vector values = CreateValues(); + std::vector indices; + std::vector shape{2, 3, 2}; + + AttributeProto& attrib = *constant_node.mutable_attribute()->Add(); + attrib.set_name("sparse_value"); + attrib.set_type(AttributeProto_AttributeType_SPARSE_TENSOR); + + SparseTensorProto& stp = *attrib.mutable_sparse_tensor(); + TensorProto& indices_tp = *stp.mutable_indices(); + + stp.mutable_dims()->Add(shape.cbegin(), shape.cend()); + for (auto dim : stp.dims()) + std::cout << dim; + + if (indices_1D) { + indices = {2, 5, 6, 10}; + indices_tp.add_dims(indices.size()); + } else { + // indices are shape {NNZ, rank} so convert flattened values of 2, 5, 6 and 10 to rank 3 values + indices_tp.add_dims(values.size()); + indices_tp.add_dims(shape.size()); + indices = { + 0, 1, 0, + 0, 2, 1, + 1, 0, 0, + 1, 2, 0}; + } + + indices_tp.mutable_int64_data()->Add(indices.cbegin(), indices.cend()); + + expected_data.resize(2 * 3 * 2); + expected_data[2] = values[0]; + expected_data[5] = values[1]; + expected_data[6] = values[2]; + expected_data[10] = values[3]; + + stp.mutable_values()->add_dims(values.size()); + inserter(values, *stp.mutable_values()); + + return constant_node; +} + +template +static void TestConversion(bool use_1D_indices, + std::function& values, TensorProto& tp)> inserter, + std::function expected, const TensorProto& actual)> checker) { + std::vector expected; + auto node = CreateConstantNode(use_1D_indices, inserter, expected); + + TensorProto dense; + utils::ConstantNodeProtoToTensorProto(node, dense); + + gsl::span expected_span = gsl::make_span(expected.data(), expected.size()); + checker(expected_span, dense); +} + +template +static void TestConversion( + std::function& values, TensorProto& tp)> inserter, + std::function expected, const TensorProto& actual)> checker) { + TestConversion(true, inserter, checker); + TestConversion(false, inserter, checker); +} + +template +static void RawDataWriter(const std::vector& values, TensorProto& tp, TensorProto_DataType datatype) { + tp.set_data_type(datatype); + tp.set_raw_data(values.data(), values.size() * sizeof(T)); +} + +template +static void RawDataChecker(gsl::span expected, const TensorProto& actual) { + int64_t actual_size = 1; + for (const auto dim : actual.dims()) { + actual_size *= dim; + } + + const T* raw_data = reinterpret_cast(actual.raw_data().data()); + auto actual_span = gsl::make_span(raw_data, actual_size); + + EXPECT_THAT(actual_span, testing::ContainerEq(expected)); +} + +TEST(SparseTensorConversionTests, TestConstantNodeConversion) { + TestConversion( + [](const std::vector& values, TensorProto& tp) { + tp.set_data_type(TensorProto_DataType_FLOAT); + tp.mutable_float_data()->Add(values.cbegin(), values.cend()); + }, + RawDataChecker); + + TestConversion( + [](const std::vector& values, TensorProto& tp) { + tp.set_data_type(TensorProto_DataType_INT32); + tp.mutable_int32_data()->Add(values.cbegin(), values.cend()); + }, + RawDataChecker); + + TestConversion( + [](const std::vector& values, TensorProto& tp) { + tp.set_data_type(TensorProto_DataType_INT64); + tp.mutable_int64_data()->Add(values.cbegin(), values.cend()); + }, + RawDataChecker); + + TestConversion( + [](const std::vector& values, TensorProto& tp) { + tp.set_data_type(TensorProto_DataType_DOUBLE); + tp.mutable_double_data()->Add(values.cbegin(), values.cend()); + }, + RawDataChecker); + + TestConversion( + [](const std::vector& values, TensorProto& tp) { + tp.set_data_type(TensorProto_DataType_UINT32); + tp.mutable_uint64_data()->Add(values.cbegin(), values.cend()); // stored in uint64_data despite being uint32_t + }, + RawDataChecker); + + TestConversion( + [](const std::vector& values, TensorProto& tp) { + tp.set_data_type(TensorProto_DataType_UINT64); + tp.mutable_uint64_data()->Add(values.cbegin(), values.cend()); + }, + RawDataChecker); + + // test a couple of types with values in raw data field + TestConversion( + [](const std::vector& values, TensorProto& tp) { + RawDataWriter(values, tp, TensorProto_DataType_FLOAT); + }, + RawDataChecker); + + TestConversion( + [](const std::vector& values, TensorProto& tp) { + RawDataWriter(values, tp, TensorProto_DataType_INT64); + }, + RawDataChecker); + + // strings can't use raw data, and string_data is a RepeatedPtrField (vs. RepeatedField for simple types) + // so has to be handled differently + TestConversion( + [](const std::vector& values, TensorProto& tp) { + tp.set_data_type(TensorProto_DataType_STRING); + for (auto cur = values.cbegin(), end = values.cend(); cur < end; ++cur) { + tp.mutable_string_data()->Add(std::string(*cur)); + } + }, + [](gsl::span expected, const TensorProto& actual) { + const auto& actual_strings = actual.string_data(); + for (int64_t i = 0, end = expected.size(); i < end; ++i) { + EXPECT_EQ(actual_strings[static_cast(i)], expected[i]); + } + }); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/framework/tensorutils_test.cc b/onnxruntime/test/framework/tensorutils_test.cc index 43ca53bd9b..633b6159a1 100644 --- a/onnxruntime/test/framework/tensorutils_test.cc +++ b/onnxruntime/test/framework/tensorutils_test.cc @@ -3,21 +3,16 @@ #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" +#include "test/util/include/asserts.h" + #include "gtest/gtest.h" +#include "gmock/gmock.h" using namespace ::onnxruntime::utils; using namespace ONNX_NAMESPACE; namespace onnxruntime { namespace test { -namespace { -template -Status UnpackTensorWrapper(const ONNX_NAMESPACE::TensorProto& tensor, /*out*/ T* p_data, int64_t expected_size) { - if (tensor.has_raw_data()) - return UnpackTensor(tensor, tensor.raw_data().data(), tensor.raw_data().size(), p_data, expected_size); - return UnpackTensor(tensor, nullptr, 0, p_data, expected_size); -} -} // namespace //T must be float for double, and it must match with the 'type' argument template @@ -32,7 +27,7 @@ void test_unpack_float_tensor(TensorProto_DataType type) { } float_tensor_proto.set_raw_data(rawdata, len); T float_data2[4]; - auto status = UnpackTensorWrapper(float_tensor_proto, float_data2, 4); + auto status = UnpackTensor(float_tensor_proto, float_data2, 4); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); EXPECT_EQ(1.1f, float_data2[0]); EXPECT_EQ(2.2f, float_data2[1]); @@ -46,12 +41,12 @@ TEST(TensorParseTest, TensorUtilsTest) { bool_tensor_proto.add_int32_data(1); bool bool_data[1]; - auto status = UnpackTensorWrapper(bool_tensor_proto, bool_data, 1); + auto status = UnpackTensor(bool_tensor_proto, bool_data, 1); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); EXPECT_TRUE(bool_data[0]); float float_data[1]; - status = UnpackTensorWrapper(bool_tensor_proto, float_data, 1); + status = UnpackTensor(bool_tensor_proto, float_data, 1); EXPECT_FALSE(status.IsOK()); test_unpack_float_tensor(TensorProto_DataType_FLOAT); @@ -63,13 +58,121 @@ TEST(TensorParseTest, TensorUtilsTest) { string_tensor_proto.add_string_data("b"); std::string string_data[2]; - status = UnpackTensorWrapper(string_tensor_proto, string_data, 2); + status = UnpackTensor(string_tensor_proto, string_data, 2); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); EXPECT_EQ("a", string_data[0]); EXPECT_EQ("b", string_data[1]); - status = UnpackTensorWrapper(bool_tensor_proto, string_data, 2); + status = UnpackTensor(bool_tensor_proto, string_data, 2); EXPECT_FALSE(status.IsOK()); } + +template +static std::vector CreateValues() { + return {1, 2, 3, 4}; +} + +template <> +std::vector CreateValues() { + return {"one", "two", "three", "four"}; +} + +template +static NodeProto CreateConstantNode(const std::string& attrib_name, AttributeProto_AttributeType type, + std::function add_data) { + NodeProto constant_node; + constant_node.set_op_type("Constant"); + constant_node.add_output("Constant_output"); + + AttributeProto& attrib = *constant_node.mutable_attribute()->Add(); + + attrib.set_name(attrib_name); + attrib.set_type(type); + add_data(attrib); + + return constant_node; +} + +template +static void TestConstantNodeConversion(const std::string& attrib_name, + AttributeProto_AttributeType type, + std::function& data)> add_data, + std::function(const TensorProto&)> get_data, + int64_t num_elements) { + auto input = CreateValues(); + if (num_elements == -1) { + num_elements = static_cast(input.size()); + } else { + input.resize(num_elements); + } + + auto c = CreateConstantNode( + attrib_name, type, + [&input, &add_data](AttributeProto& attrib) { add_data(attrib, input); }); + + TensorProto tp; + EXPECT_STATUS_OK(utils::ConstantNodeProtoToTensorProto(c, tp)); + + EXPECT_THAT(get_data(tp), ::testing::ContainerEq(input)); +} + +TEST(TensorProtoUtilsTest, ConstantTensorProto) { + TestConstantNodeConversion( + "value_float", AttributeProto_AttributeType_FLOAT, + [](AttributeProto& attrib, const std::vector& data) { attrib.set_f(data[0]); }, + [](const TensorProto& tp) { + return std::vector(tp.float_data().cbegin(), tp.float_data().cend()); + }, + 1); + + TestConstantNodeConversion( + "value_floats", AttributeProto_AttributeType_FLOATS, + [](AttributeProto& attrib, const std::vector& data) { + *attrib.mutable_floats() = {data.cbegin(), data.cend()}; + }, + [](const TensorProto& tp) { + return std::vector(tp.float_data().cbegin(), tp.float_data().cend()); + }, + -1); + + TestConstantNodeConversion( + "value_int", AttributeProto_AttributeType_INT, + [](AttributeProto& attrib, const std::vector& data) { attrib.set_i(data[0]); }, + [](const TensorProto& tp) { + return std::vector(tp.int64_data().cbegin(), tp.int64_data().cend()); + }, + 1); + + TestConstantNodeConversion( + "value_ints", AttributeProto_AttributeType_INTS, + [](AttributeProto& attrib, const std::vector& data) { + *attrib.mutable_ints() = {data.cbegin(), data.cend()}; + }, + [](const TensorProto& tp) { + return std::vector(tp.int64_data().cbegin(), tp.int64_data().cend()); + }, + -1); + + TestConstantNodeConversion( + "value_string", AttributeProto_AttributeType_STRING, + [](AttributeProto& attrib, const std::vector& data) { attrib.set_s(data[0]); }, + [](const TensorProto& tp) { + return std::vector(tp.string_data().cbegin(), tp.string_data().cend()); + }, + 1); + + TestConstantNodeConversion( + "value_strings", AttributeProto_AttributeType_STRINGS, + [](AttributeProto& attrib, const std::vector& data) { + // for (const auto& s : data) + *attrib.mutable_strings() = {data.cbegin(), data.cend()}; + }, + [](const TensorProto& tp) { + return std::vector(tp.string_data().cbegin(), tp.string_data().cend()); + }, + -1); + + // sparse_tensor is covered by SparseTensorConversionTests.TestConstantNodeConversion +} } // namespace test } // namespace onnxruntime