onnxruntime/winml/lib/Api.Core/ModelInfo.cpp
Paul McDaniel 5350abe19d
LearningModelSession is cleaned up to use the adapter, and parts of b… (#2382)
This is a big PR. We are going to move it up to layer_dev, which is still an L3 branch, so it is still safe to keep iterating there.

We are moving this into the L3 so that Ryan can start integration testing.

We will pause for a full code review and the integration test results before going into the L2.

>>>> raw comments from previous commits >>> 

* LearningModelSession is cleaned up to use the adapter, and parts of binding are.
* Moved everything into the winmladapter.
Made it all nano-COM, using WRL to construct objects on the ORT side.
Added base interfaces for everything WinML needs to call.
Cleaned up a bunch of WinML to use the base interfaces.
* more pieces
* GetData across the abi.
* Renamed some namespaces.
Cleaned up OrtValue.
Cleaned up Tensor.
Cleaned up custom ops.
Everything *but* LearningModel should be clean.
* make sure it's building.   winml.dll is still a monolith.
2019-11-14 17:44:07 -08:00

130 lines
3.6 KiB
C++

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include "pch.h"

#include <string_view>
#include <unordered_set>

#include "inc/ModelInfo.h"
#include "FeatureDescriptorFactory.h"
using namespace Windows::AI::MachineLearning;
// Collects the name of every output produced by every node in the model's
// graph, in node order and then output order.
// The returned pointers come from c_str() on strings owned by model_proto.
// They are only valid while model_proto is alive and those names are unmodified.
// NOTE(review): not referenced from the visible part of this file.
static std::vector<const char*>
GetAllNodeOutputs(const onnx::ModelProto& model_proto) {
  std::vector<const char*> produced_names;
  for (const auto& graph_node : model_proto.graph().node()) {
    const auto& node_outputs = graph_node.output();
    for (const auto& output_name : node_outputs) {
      produced_names.push_back(output_name.c_str());
    }
  }
  return produced_names;
}
// Returns the name of every initializer tensor in the model's graph, in
// graph order.
// The pointers come from c_str() on strings owned by model_proto. They stay
// valid only while model_proto is alive and its initializer names are
// unmodified.
static std::vector<const char*>
GetInitializers(const onnx::ModelProto& model_proto) {
  std::vector<const char*> initializer_names;
  for (const auto& tensor : model_proto.graph().initializer()) {
    const char* tensor_name = tensor.name().c_str();
    initializer_names.push_back(tensor_name);
  }
  return initializer_names;
}
// Returns the graph inputs that are real model inputs. An input qualifies
// when it has both a name and a type and its name does not match any
// initializer in the graph.
// The order of graph.input() is preserved. The returned pointers reference
// entries owned by model_proto.
static std::vector<const onnx::ValueInfoProto*>
GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) {
  // ONNX graphs may list initializers among graph.input() (this was required
  // before IR version 4), so those entries are filtered out here.
  // The initializer names are indexed once. Each input lookup is then O(1) on
  // average, instead of a linear strcmp scan over all initializers
  // (O(inputs * initializers) overall).
  const auto initializers = GetInitializers(model_proto);
  const std::unordered_set<std::string_view> initializer_names(
      initializers.begin(), initializers.end());

  std::vector<const onnx::ValueInfoProto*> inputs_without_initializers;
  for (const auto& input : model_proto.graph().input()) {
    // Entries without a name or a type cannot be described as features.
    if (!input.has_name() || !input.has_type()) {
      continue;
    }
    const bool is_initializer = initializer_names.count(input.name()) != 0;
    if (!is_initializer) {
      inputs_without_initializers.push_back(&input);
    }
  }
  return inputs_without_initializers;
}
// Returns pointers to the graph outputs that carry both a name and a type,
// in graph order. Outputs missing either one are skipped.
// The pointers reference entries owned by model_proto.
static std::vector<const onnx::ValueInfoProto*>
GetOutputs(const onnx::ModelProto& model_proto) {
  std::vector<const onnx::ValueInfoProto*> described_outputs;
  for (const auto& value_info : model_proto.graph().output()) {
    const bool is_fully_described = value_info.has_name() && value_info.has_type();
    if (!is_fully_described) {
      continue;
    }
    described_outputs.push_back(&value_info);
  }
  return described_outputs;
}
// Builds a ModelInfo from an ONNX model. All extraction work is delegated
// to Initialize(). model_proto is dereferenced there without a null check,
// so it must be non-null.
ModelInfo::ModelInfo(
    const onnx::ModelProto* model_proto) {
  this->Initialize(model_proto);
}
// Populates this ModelInfo from model_proto:
//   - model_metadata_ from metadata_props
//   - input_features_ and output_features_ via FeatureDescriptorFactory
//   - author_, domain_, name_ and description_, each set to "" when the
//     source field is unset
//   - version_, set to 0 when model_version is unset
// model_proto is dereferenced unconditionally and must be non-null.
void ModelInfo::Initialize(
    const onnx::ModelProto* model_proto) {
  const onnx::ModelProto& proto = *model_proto;

  // Metadata is copied first so that the descriptor factory, which is
  // constructed from model_metadata_, sees the complete map.
  for (const auto& entry : proto.metadata_props()) {
    model_metadata_[entry.key()] = entry.value();
  }

  WinML::FeatureDescriptorFactory descriptor_factory(model_metadata_);

  auto model_inputs = GetInputsWithoutInitializers(proto);
  input_features_ = descriptor_factory.CreateDescriptorsFromValueInfoProtos(model_inputs);

  // The `::` qualification makes lookup find the file-local helper.
  // NOTE(review): presumably this avoids a same-named ModelInfo member;
  // confirm before removing.
  auto model_outputs = ::GetOutputs(proto);
  output_features_ = descriptor_factory.CreateDescriptorsFromValueInfoProtos(model_outputs);

  // Returns the given value when the field is present, otherwise "".
  const auto string_or_empty = [](bool present, const std::string& value) {
    return present ? value : std::string();
  };

  author_ = string_or_empty(proto.has_producer_name(), proto.producer_name());
  domain_ = string_or_empty(proto.has_domain(), proto.domain());
  name_ = string_or_empty(proto.has_graph() && proto.graph().has_name(),
                          proto.graph().name());
  description_ = string_or_empty(proto.has_doc_string(), proto.doc_string());

  version_ = 0;
  if (proto.has_model_version()) {
    version_ = proto.model_version();
  }
}