// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. #include "pch.h" #include "inc/ModelInfo.h" #include #include #include "FeatureDescriptorFactory.h" #include "ZeroCopyInputStreamWrapper.h" #include "google/protobuf/io/zero_copy_stream_impl.h" using namespace Windows::AI::MachineLearning; static std::vector GetAllNodeOutputs(const onnx::ModelProto& model_proto) { std::vector nodes_outputs; auto& graph = model_proto.graph(); auto& nodes = graph.node(); for (auto& node : nodes) { for (auto& node_output : node.output()) { nodes_outputs.push_back(node_output.c_str()); } } return nodes_outputs; } static std::vector GetInitializers(const onnx::ModelProto& model_proto) { std::vector initializers; auto& graph = model_proto.graph(); auto& graph_initializers = graph.initializer(); for (auto& initializer : graph_initializers) { initializers.push_back(initializer.name().c_str()); } return initializers; } static std::vector GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) { auto initializers = GetInitializers(model_proto); std::vector inputs_without_initializers; auto& graph = model_proto.graph(); auto& inputs = graph.input(); for (auto& input : inputs) { if (input.has_name() && input.has_type()) { auto found_it = std::find_if( std::begin(initializers), std::end(initializers), [&](auto& initializer) { return std::strcmp(initializer, input.name().c_str()) == 0; }); auto is_initializer = found_it != std::end(initializers); if (!is_initializer) { inputs_without_initializers.push_back(&input); } } } return inputs_without_initializers; } static std::vector GetOutputs(const onnx::ModelProto& model_proto) { std::vector outputs_with_name; auto& graph = model_proto.graph(); auto& outputs = graph.output(); for (auto& output : outputs) { if (output.has_name() && output.has_type()) { outputs_with_name.push_back(&output); } } return outputs_with_name; } ModelInfo::ModelInfo( const onnx::ModelProto* model_proto) { Initialize(model_proto); } void ModelInfo::Initialize( const onnx::ModelProto* model_proto) { // metadata for (auto& prop : model_proto->metadata_props()) { model_metadata_[prop.key()] = prop.value(); } WinML::FeatureDescriptorFactory builder(model_metadata_); // Create inputs auto inputs = GetInputsWithoutInitializers(*model_proto); input_features_ = builder.CreateDescriptorsFromValueInfoProtos(inputs); // Create outputs auto outputs = ::GetOutputs(*model_proto); output_features_ = builder.CreateDescriptorsFromValueInfoProtos(outputs); // author auto has_producer_name = model_proto->has_producer_name(); author_ = has_producer_name ? model_proto->producer_name() : ""; // domain auto has_domain = model_proto->has_domain(); domain_ = has_domain ? model_proto->domain() : ""; // name auto has_graph = model_proto->has_graph(); auto graph_has_name = model_proto->graph().has_name(); auto is_name_available = has_graph && graph_has_name; name_ = is_name_available ? model_proto->graph().name() : ""; // description auto has_description = model_proto->has_doc_string(); description_ = has_description ? model_proto->doc_string() : ""; // version auto has_version = model_proto->has_model_version(); version_ = has_version ? model_proto->model_version() : 0; } // factory methods for creating an ort model from a path onnx::ModelProto* WinML::CreateModelProto( const char* path) { int file_descriptor; _sopen_s( &file_descriptor, path, O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE); THROW_HR_IF_MSG( E_FAIL, 0 > file_descriptor, "Failed"); //errno auto stream = google::protobuf::io::FileInputStream(file_descriptor); stream.SetCloseOnDelete(true); auto model_proto = new onnx::ModelProto(); THROW_HR_IF_MSG( E_INVALIDARG, !model_proto->ParseFromZeroCopyStream(&stream) == false, "The stream failed to parse."); return model_proto; } // factory methods for creating an ort model from a stream onnx::ModelProto* WinML::CreateModelProto( const wss::IRandomAccessStreamReference& stream_reference) { ZeroCopyInputStreamWrapper wrapper(stream_reference); auto model_proto = new onnx::ModelProto(); THROW_HR_IF_MSG( E_INVALIDARG, model_proto->ParseFromZeroCopyStream(&wrapper) == false, "The stream failed to parse."); return model_proto; }