mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
* Task 23998197: add winml_lib_core into onnnxruntime.dll * PR feedback build break on perf_test
177 lines
4.9 KiB
C++
177 lines
4.9 KiB
C++
// Copyright (c) Microsoft Corporation.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "pch.h"
|
|
|
|
#include "inc/ModelInfo.h"
|
|
|
|
#include <io.h>
|
|
#include <fcntl.h>
|
|
|
|
#include "FeatureDescriptorFactory.h"
|
|
#include "ZeroCopyInputStreamWrapper.h"
|
|
|
|
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
|
|
|
using namespace Windows::AI::MachineLearning;
|
|
|
|
static std::vector<const char*>
|
|
GetAllNodeOutputs(const onnx::ModelProto& model_proto) {
|
|
std::vector<const char*> nodes_outputs;
|
|
auto& graph = model_proto.graph();
|
|
auto& nodes = graph.node();
|
|
for (auto& node : nodes) {
|
|
for (auto& node_output : node.output()) {
|
|
nodes_outputs.push_back(node_output.c_str());
|
|
}
|
|
}
|
|
return nodes_outputs;
|
|
}
|
|
|
|
static std::vector<const char*>
|
|
GetInitializers(const onnx::ModelProto& model_proto) {
|
|
std::vector<const char*> initializers;
|
|
auto& graph = model_proto.graph();
|
|
auto& graph_initializers = graph.initializer();
|
|
for (auto& initializer : graph_initializers) {
|
|
initializers.push_back(initializer.name().c_str());
|
|
}
|
|
return initializers;
|
|
}
|
|
|
|
static std::vector<const onnx::ValueInfoProto*>
|
|
GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) {
|
|
auto initializers = GetInitializers(model_proto);
|
|
|
|
std::vector<const onnx::ValueInfoProto*> inputs_without_initializers;
|
|
auto& graph = model_proto.graph();
|
|
auto& inputs = graph.input();
|
|
for (auto& input : inputs) {
|
|
if (input.has_name() && input.has_type()) {
|
|
auto found_it = std::find_if(
|
|
std::begin(initializers),
|
|
std::end(initializers),
|
|
[&](auto& initializer) {
|
|
return std::strcmp(initializer, input.name().c_str()) == 0;
|
|
});
|
|
|
|
auto is_initializer = found_it != std::end(initializers);
|
|
if (!is_initializer) {
|
|
inputs_without_initializers.push_back(&input);
|
|
}
|
|
}
|
|
}
|
|
return inputs_without_initializers;
|
|
}
|
|
|
|
static std::vector<const onnx::ValueInfoProto*>
|
|
GetOutputs(const onnx::ModelProto& model_proto) {
|
|
std::vector<const onnx::ValueInfoProto*> outputs_with_name;
|
|
auto& graph = model_proto.graph();
|
|
auto& outputs = graph.output();
|
|
for (auto& output : outputs) {
|
|
if (output.has_name() && output.has_type()) {
|
|
outputs_with_name.push_back(&output);
|
|
}
|
|
}
|
|
return outputs_with_name;
|
|
}
|
|
|
|
ModelInfo::ModelInfo(
|
|
const onnx::ModelProto* model_proto) {
|
|
Initialize(model_proto);
|
|
}
|
|
|
|
void ModelInfo::Initialize(
|
|
const onnx::ModelProto* model_proto) {
|
|
// metadata
|
|
for (auto& prop : model_proto->metadata_props()) {
|
|
model_metadata_[prop.key()] = prop.value();
|
|
}
|
|
|
|
WinML::FeatureDescriptorFactory builder(model_metadata_);
|
|
|
|
// Create inputs
|
|
auto inputs = GetInputsWithoutInitializers(*model_proto);
|
|
input_features_ = builder.CreateDescriptorsFromValueInfoProtos(inputs);
|
|
|
|
// Create outputs
|
|
auto outputs = ::GetOutputs(*model_proto);
|
|
output_features_ = builder.CreateDescriptorsFromValueInfoProtos(outputs);
|
|
|
|
// author
|
|
auto has_producer_name = model_proto->has_producer_name();
|
|
author_ = has_producer_name
|
|
? model_proto->producer_name()
|
|
: "";
|
|
|
|
// domain
|
|
auto has_domain = model_proto->has_domain();
|
|
domain_ = has_domain
|
|
? model_proto->domain()
|
|
: "";
|
|
|
|
// name
|
|
auto has_graph = model_proto->has_graph();
|
|
auto graph_has_name = model_proto->graph().has_name();
|
|
auto is_name_available = has_graph && graph_has_name;
|
|
name_ = is_name_available
|
|
? model_proto->graph().name()
|
|
: "";
|
|
|
|
// description
|
|
auto has_description = model_proto->has_doc_string();
|
|
description_ = has_description
|
|
? model_proto->doc_string()
|
|
: "";
|
|
|
|
// version
|
|
auto has_version = model_proto->has_model_version();
|
|
version_ = has_version
|
|
? model_proto->model_version()
|
|
: 0;
|
|
}
|
|
|
|
// factory methods for creating an ort model from a path
|
|
onnx::ModelProto*
|
|
WinML::CreateModelProto(
|
|
const char* path) {
|
|
int file_descriptor;
|
|
_sopen_s(
|
|
&file_descriptor,
|
|
path,
|
|
O_RDONLY | _O_SEQUENTIAL | _O_BINARY,
|
|
_SH_DENYWR,
|
|
_S_IREAD | _S_IWRITE);
|
|
|
|
THROW_HR_IF_MSG(
|
|
E_FAIL,
|
|
0 > file_descriptor,
|
|
"Failed"); //errno
|
|
|
|
auto stream = google::protobuf::io::FileInputStream(file_descriptor);
|
|
stream.SetCloseOnDelete(true);
|
|
|
|
auto model_proto = new onnx::ModelProto();
|
|
THROW_HR_IF_MSG(
|
|
E_INVALIDARG,
|
|
!model_proto->ParseFromZeroCopyStream(&stream) == false,
|
|
"The stream failed to parse.");
|
|
|
|
return model_proto;
|
|
}
|
|
|
|
// factory methods for creating an ort model from a stream
|
|
onnx::ModelProto*
|
|
WinML::CreateModelProto(
|
|
const wss::IRandomAccessStreamReference& stream_reference) {
|
|
ZeroCopyInputStreamWrapper wrapper(stream_reference);
|
|
|
|
auto model_proto = new onnx::ModelProto();
|
|
THROW_HR_IF_MSG(
|
|
E_INVALIDARG,
|
|
model_proto->ParseFromZeroCopyStream(&wrapper) == false,
|
|
"The stream failed to parse.");
|
|
|
|
return model_proto;
|
|
}
|