onnxruntime/onnxruntime/test/util/test_utils.cc
Guoyu Wang e4dc4e4d3c
[NNAPI QDQ] Add QDQ Add/Mul, update NNAPI QDQ handling, update some test settings (#10483)
* Squashed commit of the following:

commit 12380491a9
Author: Guoyu Wang <wanggy@outlook.com>
Date:   Mon Feb 7 12:59:04 2022 -0800

    Add qdq mul support

commit 9cadda7f2c
Merge: 7a32847761 0f5d0a091a
Author: Guoyu Wang <wanggy@outlook.com>
Date:   Mon Feb 7 11:24:47 2022 -0800

    Merge remote-tracking branch 'origin/master' into gwang-msft/qdq_mul

commit 7a32847761
Author: Guoyu Wang <wanggy@outlook.com>
Date:   Mon Feb 7 00:41:30 2022 -0800

    move test case to util

commit c1a8f0d81e
Author: Guoyu Wang <wanggy@outlook.com>
Date:   Fri Feb 4 13:04:26 2022 -0800

    update input/output check

commit a6f0a0d504
Author: Guoyu Wang <wanggy@outlook.com>
Date:   Thu Feb 3 18:37:21 2022 -0800

    update quantized io check functions

commit 87f4d1dcfe
Merge: 7849f07109 97b8f6f394
Author: Guoyu Wang <wanggy@outlook.com>
Date:   Wed Feb 2 17:22:58 2022 -0800

    Merge remote-tracking branch 'origin/master' into gwang-msft/qdq_mul

commit 7849f07109
Author: Guoyu Wang <wanggy@outlook.com>
Date:   Wed Feb 2 17:22:55 2022 -0800

    minor update

commit 7196cdf419
Author: Guoyu Wang <wanggy@outlook.com>
Date:   Wed Feb 2 10:50:10 2022 -0800

    init change

commit 84c00772a1
Merge: a8c7dce22f 7318361645
Author: Guoyu Wang <wanggy@outlook.com>
Date:   Tue Feb 1 18:21:17 2022 -0800

    Merge remote-tracking branch 'origin/master' into gwang-msft/qdq_mul

commit a8c7dce22f
Merge: 55e536c182 ef7b4dc05c
Author: Guoyu Wang <wanggy@outlook.com>
Date:   Tue Feb 1 13:51:04 2022 -0800

    Merge remote-tracking branch 'origin/master' into gwang-msft/qdq_mul

commit 55e536c182
Author: Guoyu Wang <wanggy@outlook.com>
Date:   Tue Feb 1 11:44:34 2022 -0800

    address cr comments

commit d460f5b776
Author: Guoyu Wang <wanggy@outlook.com>
Date:   Tue Feb 1 00:33:54 2022 -0800

    fix android UT failure

commit 52146cf06f
Author: Guoyu Wang <wanggy@outlook.com>
Date:   Mon Jan 31 16:01:13 2022 -0800

    fix build break

commit ec6d07df8b
Author: Guoyu Wang <wanggy@outlook.com>
Date:   Mon Jan 31 15:41:52 2022 -0800

    minor update to UT

commit 8ec8490b4f
Author: Guoyu Wang <wanggy@outlook.com>
Date:   Mon Jan 31 15:01:30 2022 -0800

    Add NNAPI support of QDQ Resize

* Update qdq add/mul test case, fix build break

* Address CR comments

* Add QLinearMul support

* remove unused params

* Address CR comments
2022-02-08 20:44:15 -08:00

203 lines
8.4 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "test/util/include/test_utils.h"
#include "core/framework/ort_value.h"
#include "core/graph/onnx_protobuf.h"
#include "core/session/inference_session.h"
#include "core/framework/tensorprotoutils.h"
#include "test/util/include/asserts.h"
#include "test/util/include/test/test_environment.h"
#include "test/util/include/inference_session_wrapper.h"
#include "gmock/gmock.h"
namespace onnxruntime {
namespace test {
// Compares the outputs produced with the EP under test (`fetches`) against the reference outputs
// (`expected_fetches`), pairing them by index.
//
// - output_names:     names used only to label failures; must be index-aligned with the fetch vectors.
// - expected_fetches: reference output tensors.
// - fetches:          output tensors to verify.
// - params:           supplies fp32_abs_err, the absolute tolerance used for FLOAT outputs.
//
// Integer and bool outputs must match exactly; FLOAT outputs must match element-wise within
// params.fp32_abs_err. Throws (ORT_THROW) for an element type that has no comparison case below.
static void VerifyOutputs(const std::vector<std::string>& output_names,
                          const std::vector<OrtValue>& expected_fetches,
                          const std::vector<OrtValue>& fetches,
                          const EPVerificationParams& params) {
  ASSERT_EQ(expected_fetches.size(), fetches.size());
  // output_names[i] is used to label failures, so it must cover every fetched output.
  ASSERT_EQ(output_names.size(), fetches.size());

  for (size_t i = 0, end = expected_fetches.size(); i < end; ++i) {
    auto& ltensor = expected_fetches[i].Get<Tensor>();
    auto& rtensor = fetches[i].Get<Tensor>();
    ASSERT_EQ(ltensor.Shape().GetDims(), rtensor.Shape().GetDims()) << " shape mismatch for " << output_names[i];

    auto element_type = ltensor.GetElementType();
    switch (element_type) {
      case ONNX_NAMESPACE::TensorProto_DataType_INT32:
        EXPECT_THAT(ltensor.DataAsSpan<int32_t>(), ::testing::ContainerEq(rtensor.DataAsSpan<int32_t>()))
            << " mismatch for " << output_names[i];
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT64:
        EXPECT_THAT(ltensor.DataAsSpan<int64_t>(), ::testing::ContainerEq(rtensor.DataAsSpan<int64_t>()))
            << " mismatch for " << output_names[i];
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
        EXPECT_THAT(ltensor.DataAsSpan<uint8_t>(), ::testing::ContainerEq(rtensor.DataAsSpan<uint8_t>()))
            << " mismatch for " << output_names[i];
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT8:
        // Signed 8-bit outputs (e.g. from QDQ models quantized to int8) are compared exactly.
        EXPECT_THAT(ltensor.DataAsSpan<int8_t>(), ::testing::ContainerEq(rtensor.DataAsSpan<int8_t>()))
            << " mismatch for " << output_names[i];
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
        EXPECT_THAT(ltensor.DataAsSpan<bool>(), ::testing::ContainerEq(rtensor.DataAsSpan<bool>()))
            << " mismatch for " << output_names[i];
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
        // EPs may compute floats differently from the reference, so allow an absolute tolerance.
        EXPECT_THAT(ltensor.DataAsSpan<float>(),
                    ::testing::Pointwise(::testing::FloatNear(params.fp32_abs_err), rtensor.DataAsSpan<float>()))
            << " mismatch for " << output_names[i];
        break;
      }
      default:
        ORT_THROW("Unhandled data type. Please add 'case' statement for ", element_type);
    }
  }
}
// Returns how many nodes in `current_graph` report `ep_type` as their execution provider type.
// Nodes inside subgraphs (nodes for which ContainsSubgraph() is true) are counted recursively,
// so the result covers the whole nested graph hierarchy, not just the top level.
int CountAssignedNodes(const Graph& current_graph, const std::string& ep_type) {
  int total = 0;
  for (const auto& n : current_graph.Nodes()) {
    total += (n.GetExecutionProviderType() == ep_type) ? 1 : 0;

    if (!n.ContainsSubgraph()) {
      continue;
    }

    // Descend into every subgraph owned by this node.
    for (const auto& subgraph : n.GetSubgraphs()) {
      total += CountAssignedNodes(*subgraph, ep_type);
    }
  }
  return total;
}
void RunAndVerifyOutputsWithEP(const ORTCHAR_T* model_path, const char* log_id,
std::unique_ptr<IExecutionProvider> execution_provider,
const NameMLValMap& feeds,
const EPVerificationParams& params) {
// read raw data from model provided by the model_path
std::ifstream stream(model_path, std::ios::in | std::ios::binary);
std::string model_data((std::istreambuf_iterator<char>(stream)), std::istreambuf_iterator<char>());
RunAndVerifyOutputsWithEP(model_data, log_id, std::move(execution_provider), feeds, params);
}
void RunAndVerifyOutputsWithEP(const std::string& model_data, const char* log_id,
std::unique_ptr<IExecutionProvider> execution_provider,
const NameMLValMap& feeds,
const EPVerificationParams& params) {
SessionOptions so;
so.session_logid = log_id;
RunOptions run_options;
run_options.run_tag = so.session_logid;
//
// get expected output from CPU EP
//
InferenceSessionWrapper session_object{so, GetEnvironment()};
ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
ASSERT_STATUS_OK(session_object.Initialize());
const auto& graph = session_object.GetGraph();
const auto& outputs = graph.GetOutputs();
// fetch all outputs
std::vector<std::string> output_names;
output_names.reserve(outputs.size());
for (const auto* node_arg : outputs) {
if (node_arg->Exists()) {
output_names.push_back(node_arg->Name());
}
}
std::vector<OrtValue> expected_fetches;
ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &expected_fetches));
auto provider_type = execution_provider->Type(); // copy string so the std::move doesn't affect us
//
// get output with EP enabled
//
InferenceSessionWrapper session_object2{so, GetEnvironment()};
ASSERT_STATUS_OK(session_object2.RegisterExecutionProvider(std::move(execution_provider)));
ASSERT_STATUS_OK(session_object2.Load(model_data.data(), static_cast<int>(model_data.size())));
ASSERT_STATUS_OK(session_object2.Initialize());
// make sure that some nodes are assigned to the EP, otherwise this test is pointless...
const auto& graph2 = session_object2.GetGraph();
auto ep_nodes = CountAssignedNodes(graph2, provider_type);
if (params.verify_entire_graph_use_ep) {
// Verify the entire graph is assigned to the EP
ASSERT_EQ(ep_nodes, graph2.NumberOfNodes()) << "Not all nodes were assigned to " << provider_type;
} else {
ASSERT_GT(ep_nodes, 0) << "No nodes were assigned to " << provider_type;
}
// Run with EP and verify the result
std::vector<OrtValue> fetches;
ASSERT_STATUS_OK(session_object2.Run(run_options, feeds, output_names, &fetches));
VerifyOutputs(output_names, expected_fetches, fetches, params);
}
#if !defined(DISABLE_SPARSE_TENSORS)
// Verifies that the index values stored in `indices_proto` equal `expected_indices`.
//
// Storage rules enforced per data type:
//   - INT64 / INT32: accepted either as raw_data or in the typed repeated field (int64_data / int32_data).
//   - INT16 / INT8:  must be stored as raw_data; a non-raw proto fails the check.
//   - any other type: fails the check.
// Whenever raw_data is used, its byte size must equal the element count (product of dims) times
// the element size. Indices narrower than int64_t are widened to int64_t before the comparison.
void SparseIndicesChecker(const ONNX_NAMESPACE::TensorProto& indices_proto, gsl::span<const int64_t> expected_indices) {
  // Default-constructed path passed to UnpackInitializerData as the model path.
  Path model_path;
  std::vector<uint8_t> unpack_buffer;
  gsl::span<const int64_t> ind_span;
  // Backing storage for ind_span when the stored type is narrower than int64_t.
  std::vector<int64_t> converted_indices;
  TensorShape ind_shape(indices_proto.dims().data(), indices_proto.dims().size());
  const auto elements = gsl::narrow<size_t>(ind_shape.Size());
  const bool has_raw_data = indices_proto.has_raw_data();

  switch (indices_proto.data_type()) {
    case ONNX_NAMESPACE::TensorProto_DataType_INT64: {
      if (has_raw_data) {
        const auto& rd = indices_proto.raw_data();
        ASSERT_EQ(rd.size(), elements * sizeof(int64_t));
        ASSERT_STATUS_OK(utils::UnpackInitializerData(indices_proto, model_path, unpack_buffer));
        // View the unpacked bytes directly as int64_t; unpack_buffer outlives ind_span.
        ind_span = gsl::make_span(unpack_buffer).as_span<const int64_t>();
      } else {
        ind_span = gsl::make_span(indices_proto.int64_data().cbegin(), indices_proto.int64_data().cend());
      }
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
      if (has_raw_data) {
        const auto& rd = indices_proto.raw_data();
        ASSERT_EQ(rd.size(), elements * sizeof(int32_t));
        ASSERT_STATUS_OK(utils::UnpackInitializerData(indices_proto, model_path, unpack_buffer));
        auto int32_span = gsl::make_span(unpack_buffer).as_span<const int32_t>();
        converted_indices.insert(converted_indices.cend(), int32_span.cbegin(), int32_span.cend());
      } else {
        converted_indices.insert(converted_indices.cend(), indices_proto.int32_data().cbegin(),
                                 indices_proto.int32_data().cend());
      }
      ind_span = gsl::make_span(converted_indices);
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
      ASSERT_TRUE(has_raw_data);
      const auto& rd = indices_proto.raw_data();
      ASSERT_EQ(rd.size(), elements * sizeof(int16_t));
      ASSERT_STATUS_OK(utils::UnpackInitializerData(indices_proto, model_path, unpack_buffer));
      auto int16_span = gsl::make_span(unpack_buffer).as_span<const int16_t>();
      converted_indices.insert(converted_indices.cend(), int16_span.cbegin(), int16_span.cend());
      ind_span = gsl::make_span(converted_indices);
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
      ASSERT_TRUE(has_raw_data);
      const auto& rd = indices_proto.raw_data();
      ASSERT_EQ(rd.size(), elements * sizeof(int8_t));
      ASSERT_STATUS_OK(utils::UnpackInitializerData(indices_proto, model_path, unpack_buffer));
      auto int8_span = gsl::make_span(unpack_buffer).as_span<const int8_t>();
      converted_indices.insert(converted_indices.cend(), int8_span.cbegin(), int8_span.cend());
      ind_span = gsl::make_span(converted_indices);
      break;
    }
    default:
      FAIL() << "Unsupported sparse indices data type: " << indices_proto.data_type();
  }

  ASSERT_THAT(ind_span, testing::ContainerEq(expected_indices));
}
#endif  // DISABLE_SPARSE_TENSORS
} // namespace test
} // namespace onnxruntime