mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
236 lines
9.7 KiB
C++
236 lines
9.7 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "test/util/include/test_utils.h"
|
|
|
|
#include "core/common/narrow.h"
|
|
#include "core/common/span_utils.h"
|
|
#include "core/framework/ort_value.h"
|
|
#include "core/graph/onnx_protobuf.h"
|
|
#include "core/session/inference_session.h"
|
|
#include "core/framework/tensorprotoutils.h"
|
|
|
|
#include "test/util/include/asserts.h"
|
|
#include "test/util/include/test/test_environment.h"
|
|
#include "test/util/include/inference_session_wrapper.h"
|
|
|
|
#include "gmock/gmock.h"
|
|
|
|
namespace onnxruntime {
|
|
namespace test {
|
|
static void VerifyOutputs(const std::vector<std::string>& output_names,
|
|
const std::vector<OrtValue>& expected_fetches,
|
|
const std::vector<OrtValue>& fetches,
|
|
const EPVerificationParams& params) {
|
|
ASSERT_EQ(expected_fetches.size(), fetches.size());
|
|
|
|
for (size_t i = 0, end = expected_fetches.size(); i < end; ++i) {
|
|
auto& ltensor = expected_fetches[i].Get<Tensor>();
|
|
auto& rtensor = fetches[i].Get<Tensor>();
|
|
ASSERT_TRUE(SpanEq(ltensor.Shape().GetDims(), rtensor.Shape().GetDims()));
|
|
auto element_type = ltensor.GetElementType();
|
|
switch (element_type) {
|
|
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
|
|
EXPECT_TRUE(SpanEq(ltensor.DataAsSpan<int32_t>(), rtensor.DataAsSpan<int32_t>()))
|
|
<< " mismatch for " << output_names[i];
|
|
break;
|
|
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
|
|
EXPECT_TRUE(SpanEq(ltensor.DataAsSpan<int64_t>(), rtensor.DataAsSpan<int64_t>()))
|
|
<< " mismatch for " << output_names[i];
|
|
break;
|
|
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
|
EXPECT_TRUE(SpanEq(ltensor.DataAsSpan<uint8_t>(), rtensor.DataAsSpan<uint8_t>()))
|
|
<< " mismatch for " << output_names[i];
|
|
break;
|
|
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
|
EXPECT_TRUE(SpanEq(ltensor.DataAsSpan<int8_t>(), rtensor.DataAsSpan<int8_t>()))
|
|
<< " mismatch for " << output_names[i];
|
|
break;
|
|
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
|
|
EXPECT_THAT(ltensor.DataAsSpan<float>(),
|
|
::testing::Pointwise(::testing::FloatNear(params.fp32_abs_err), rtensor.DataAsSpan<float>()));
|
|
break;
|
|
}
|
|
default:
|
|
ORT_THROW("Unhandled data type. Please add 'case' statement for ", element_type);
|
|
}
|
|
}
|
|
}
|
|
|
|
int CountAssignedNodes(const Graph& current_graph, const std::string& ep_type) {
|
|
int count = 0;
|
|
|
|
for (const auto& node : current_graph.Nodes()) {
|
|
if (node.GetExecutionProviderType() == ep_type) {
|
|
++count;
|
|
}
|
|
|
|
if (node.ContainsSubgraph()) {
|
|
for (const auto& entry : node.GetSubgraphs()) {
|
|
count += CountAssignedNodes(*entry, ep_type);
|
|
}
|
|
}
|
|
}
|
|
|
|
return count;
|
|
}
|
|
|
|
void RunAndVerifyOutputsWithEP(const ORTCHAR_T* model_path, const char* log_id,
|
|
std::unique_ptr<IExecutionProvider> execution_provider,
|
|
const NameMLValMap& feeds,
|
|
const EPVerificationParams& params) {
|
|
// read raw data from model provided by the model_path
|
|
std::ifstream stream(model_path, std::ios::in | std::ios::binary);
|
|
std::string model_data((std::istreambuf_iterator<char>(stream)), std::istreambuf_iterator<char>());
|
|
RunAndVerifyOutputsWithEP(model_data, log_id, std::move(execution_provider), feeds, params);
|
|
}
|
|
|
|
void RunAndVerifyOutputsWithEP(const std::string& model_data, const char* log_id,
|
|
std::unique_ptr<IExecutionProvider> execution_provider,
|
|
const NameMLValMap& feeds,
|
|
const EPVerificationParams& params) {
|
|
SessionOptions so;
|
|
so.session_logid = log_id;
|
|
RunOptions run_options;
|
|
run_options.run_tag = so.session_logid;
|
|
|
|
//
|
|
// get expected output from CPU EP
|
|
//
|
|
InferenceSessionWrapper session_object{so, GetEnvironment()};
|
|
ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
|
|
ASSERT_STATUS_OK(session_object.Initialize());
|
|
|
|
const auto& graph = session_object.GetGraph();
|
|
const auto& outputs = graph.GetOutputs();
|
|
|
|
// fetch all outputs
|
|
std::vector<std::string> output_names;
|
|
output_names.reserve(outputs.size());
|
|
for (const auto* node_arg : outputs) {
|
|
if (node_arg->Exists()) {
|
|
output_names.push_back(node_arg->Name());
|
|
}
|
|
}
|
|
|
|
std::vector<OrtValue> expected_fetches;
|
|
ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &expected_fetches));
|
|
|
|
auto provider_type = execution_provider->Type(); // copy string so the std::move doesn't affect us
|
|
|
|
//
|
|
// get output with EP enabled
|
|
//
|
|
InferenceSessionWrapper session_object2{so, GetEnvironment()};
|
|
ASSERT_STATUS_OK(session_object2.RegisterExecutionProvider(std::move(execution_provider)));
|
|
ASSERT_STATUS_OK(session_object2.Load(model_data.data(), static_cast<int>(model_data.size())));
|
|
ASSERT_STATUS_OK(session_object2.Initialize());
|
|
|
|
// make sure that some nodes are assigned to the EP, otherwise this test is pointless...
|
|
const auto& graph2 = session_object2.GetGraph();
|
|
auto ep_nodes = CountAssignedNodes(graph2, provider_type);
|
|
if (params.ep_node_assignment == ExpectedEPNodeAssignment::All) {
|
|
// Verify the entire graph is assigned to the EP
|
|
ASSERT_EQ(ep_nodes, graph2.NumberOfNodes()) << "Not all nodes were assigned to " << provider_type;
|
|
} else if (params.ep_node_assignment == ExpectedEPNodeAssignment::None) {
|
|
// Check if expected failure path is correctly handled by ep. (only used in NNAPI EP QDQ model test case for now)
|
|
ASSERT_EQ(ep_nodes, 0) << "No nodes are supposed to be assigned to " << provider_type;
|
|
} else {
|
|
ASSERT_GT(ep_nodes, 0) << "No nodes were assigned to " << provider_type;
|
|
}
|
|
|
|
// Run with EP and verify the result
|
|
std::vector<OrtValue> fetches;
|
|
ASSERT_STATUS_OK(session_object2.Run(run_options, feeds, output_names, &fetches));
|
|
VerifyOutputs(output_names, expected_fetches, fetches, params);
|
|
|
|
if (params.graph_verifier) {
|
|
(*params.graph_verifier)(graph2);
|
|
}
|
|
}
|
|
|
|
// Verifies that two shape protos describe the same shape: same rank and, for each dimension,
// the same presence and value of dim_value and of dim_param.
//
// A null `shape1` or `shape2` is a fatal failure and ends the check immediately. All other
// differences are reported with EXPECT_* so every differing dimension is listed; when the ranks
// differ, only the dimensions present in both shapes are compared.
void CheckShapeEquality(const ONNX_NAMESPACE::TensorShapeProto* shape1,
                        const ONNX_NAMESPACE::TensorShapeProto* shape2) {
  // Both pointers are dereferenced below, so a null shape must stop here rather than
  // being recorded as a non-fatal failure and then dereferenced.
  ASSERT_NE(shape1, nullptr);
  ASSERT_NE(shape2, nullptr);

  EXPECT_EQ(shape1->dim_size(), shape2->dim_size()) << "Shapes do not have same rank";
  auto min_dims = std::min(shape1->dim_size(), shape2->dim_size());
  for (int i = 0; i < min_dims; ++i) {
    // bind by reference: the dimension messages are only inspected, so no copy is needed
    const auto& dim1 = shape1->dim(i);
    const auto& dim2 = shape2->dim(i);

    EXPECT_EQ(dim1.has_dim_value(), dim2.has_dim_value());
    if (dim1.has_dim_value()) {
      EXPECT_EQ(dim1.dim_value(), dim2.dim_value());
    }

    EXPECT_EQ(dim1.has_dim_param(), dim2.has_dim_param());
    if (dim1.has_dim_param()) {
      EXPECT_EQ(dim1.dim_param(), dim2.dim_param());
    }
  }
}
|
|
|
|
#if !defined(DISABLE_SPARSE_TENSORS)
|
|
// Asserts that the indices stored in `indices_proto` are equal to `expected_indicies`.
//
// The indices may be stored as INT64, INT32, INT16 or INT8. INT64 and INT32 are read either from
// raw_data or from the typed int64_data / int32_data field; INT16 and INT8 are required to be
// stored as raw_data. Whenever raw_data is used, its byte size must equal the element count
// implied by the proto's dims times the element size. Narrower types are widened to int64_t
// before the comparison. Any other data type is a fatal failure.
void SparseIndicesChecker(const ONNX_NAMESPACE::TensorProto& indices_proto, gsl::span<const int64_t> expected_indicies) {
  using namespace ONNX_NAMESPACE;
  Path model_path;                    // default-constructed; passed as-is to UnpackInitializerData
  std::vector<uint8_t> raw_bytes;     // receives the unpacked raw_data bytes
  gsl::span<const int64_t> ind_span;  // view of the indices that is finally compared
  std::vector<int64_t> widened;       // backing storage when the stored type is narrower than int64_t
  TensorShape indices_shape(indices_proto.dims().data(), indices_proto.dims().size());
  const auto elements = narrow<size_t>(indices_shape.Size());
  const bool has_raw_data = indices_proto.has_raw_data();

  switch (indices_proto.data_type()) {
    case ONNX_NAMESPACE::TensorProto_DataType_INT64: {
      if (!has_raw_data) {
        // already int64_t, so the typed field can be viewed directly
        ind_span = gsl::make_span(indices_proto.int64_data().data(), indices_proto.int64_data_size());
      } else {
        const auto& rd = indices_proto.raw_data();
        ASSERT_EQ(rd.size(), elements * sizeof(int64_t));
        ASSERT_STATUS_OK(utils::UnpackInitializerData(indices_proto, model_path, raw_bytes));
        ind_span = ReinterpretAsSpan<const int64_t>(gsl::make_span(raw_bytes));
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
      if (!has_raw_data) {
        widened.assign(indices_proto.int32_data().cbegin(), indices_proto.int32_data().cend());
      } else {
        const auto& rd = indices_proto.raw_data();
        ASSERT_EQ(rd.size(), elements * sizeof(int32_t));
        ASSERT_STATUS_OK(utils::UnpackInitializerData(indices_proto, model_path, raw_bytes));
        const auto typed_view = ReinterpretAsSpan<const int32_t>(gsl::make_span(raw_bytes));
        widened.assign(typed_view.begin(), typed_view.end());
      }
      ind_span = gsl::make_span(widened);
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
      ASSERT_TRUE(has_raw_data);
      const auto& rd = indices_proto.raw_data();
      ASSERT_EQ(rd.size(), elements * sizeof(int16_t));
      ASSERT_STATUS_OK(utils::UnpackInitializerData(indices_proto, model_path, raw_bytes));
      const auto typed_view = ReinterpretAsSpan<const int16_t>(gsl::make_span(raw_bytes));
      widened.assign(typed_view.begin(), typed_view.end());
      ind_span = gsl::make_span(widened);
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
      ASSERT_TRUE(has_raw_data);
      const auto& rd = indices_proto.raw_data();
      ASSERT_EQ(rd.size(), elements);
      ASSERT_STATUS_OK(utils::UnpackInitializerData(indices_proto, model_path, raw_bytes));
      const auto typed_view = ReinterpretAsSpan<const int8_t>(gsl::make_span(raw_bytes));
      widened.assign(typed_view.begin(), typed_view.end());
      ind_span = gsl::make_span(widened);
      break;
    }
    default:
      // no other index data type is handled by this checker
      ASSERT_TRUE(false);
  }

  ASSERT_TRUE(SpanEq(ind_span, expected_indicies));
}
|
|
|
|
#endif // DISABLE_SPARSE_TENSORS
|
|
|
|
} // namespace test
|
|
} // namespace onnxruntime
|