mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
model moved over.
everything builds clean. step !
This commit is contained in:
parent
00cee34ec0
commit
f07fdf96b4
7 changed files with 251 additions and 215 deletions
|
|
@ -471,24 +471,24 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
|
|||
endif("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
|
||||
|
||||
# Link libraries
|
||||
target_link_libraries(winml_dll PRIVATE libprotobuf)
|
||||
target_link_libraries(winml_dll PRIVATE onnx)
|
||||
target_link_libraries(winml_dll PRIVATE onnxruntime_common)
|
||||
target_link_libraries(winml_dll PRIVATE onnxruntime_graph)
|
||||
target_link_libraries(winml_dll PRIVATE onnxruntime_framework)
|
||||
target_link_libraries(winml_dll PRIVATE onnxruntime_mlas)
|
||||
target_link_libraries(winml_dll PRIVATE onnxruntime_optimizer)
|
||||
target_link_libraries(winml_dll PRIVATE onnxruntime_providers)
|
||||
target_link_libraries(winml_dll PRIVATE onnxruntime_providers_dml)
|
||||
target_link_libraries(winml_dll PRIVATE onnxruntime_session)
|
||||
target_link_libraries(winml_dll PRIVATE onnxruntime_util)
|
||||
target_link_libraries(winml_dll PRIVATE onnx_proto)
|
||||
#target_link_libraries(winml_dll PRIVATE libprotobuf)
|
||||
#target_link_libraries(winml_dll PRIVATE onnx)
|
||||
#target_link_libraries(winml_dll PRIVATE onnxruntime_common)
|
||||
#target_link_libraries(winml_dll PRIVATE onnxruntime_graph)
|
||||
#target_link_libraries(winml_dll PRIVATE onnxruntime_framework)
|
||||
#target_link_libraries(winml_dll PRIVATE onnxruntime_mlas)
|
||||
#target_link_libraries(winml_dll PRIVATE onnxruntime_optimizer)
|
||||
#target_link_libraries(winml_dll PRIVATE onnxruntime_providers)
|
||||
#target_link_libraries(winml_dll PRIVATE onnxruntime_providers_dml)
|
||||
#target_link_libraries(winml_dll PRIVATE onnxruntime_session)
|
||||
#target_link_libraries(winml_dll PRIVATE onnxruntime_util)
|
||||
#target_link_libraries(winml_dll PRIVATE onnx_proto)
|
||||
target_link_libraries(winml_dll PRIVATE onnxruntime)
|
||||
target_link_libraries(winml_dll PRIVATE re2)
|
||||
target_link_libraries(winml_dll PRIVATE wil)
|
||||
target_link_libraries(winml_dll PRIVATE windowsapp.lib)
|
||||
target_link_libraries(winml_dll PRIVATE winml_lib_api)
|
||||
target_link_libraries(winml_dll PRIVATE winml_lib_core)
|
||||
#target_link_libraries(winml_dll PRIVATE winml_lib_core)
|
||||
target_link_libraries(winml_dll PRIVATE winml_lib_image)
|
||||
target_link_libraries(winml_dll PRIVATE winml_lib_telemetry)
|
||||
target_link_libraries(winml_dll PRIVATE ${DBGHELP})
|
||||
|
|
|
|||
|
|
@ -10,121 +10,5 @@
|
|||
|
||||
using namespace Windows::AI::MachineLearning;
|
||||
|
||||
static std::vector<const char*>
|
||||
GetAllNodeOutputs(const onnx::ModelProto& model_proto) {
|
||||
std::vector<const char*> nodes_outputs;
|
||||
auto& graph = model_proto.graph();
|
||||
auto& nodes = graph.node();
|
||||
for (auto& node : nodes) {
|
||||
for (auto& node_output : node.output()) {
|
||||
nodes_outputs.push_back(node_output.c_str());
|
||||
}
|
||||
}
|
||||
return nodes_outputs;
|
||||
}
|
||||
|
||||
static std::vector<const char*>
|
||||
GetInitializers(const onnx::ModelProto& model_proto) {
|
||||
std::vector<const char*> initializers;
|
||||
auto& graph = model_proto.graph();
|
||||
auto& graph_initializers = graph.initializer();
|
||||
for (auto& initializer : graph_initializers) {
|
||||
initializers.push_back(initializer.name().c_str());
|
||||
}
|
||||
return initializers;
|
||||
}
|
||||
|
||||
static std::vector<const onnx::ValueInfoProto*>
|
||||
GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) {
|
||||
auto initializers = GetInitializers(model_proto);
|
||||
|
||||
std::vector<const onnx::ValueInfoProto*> inputs_without_initializers;
|
||||
auto& graph = model_proto.graph();
|
||||
auto& inputs = graph.input();
|
||||
for (auto& input : inputs) {
|
||||
if (input.has_name() && input.has_type()) {
|
||||
auto found_it = std::find_if(
|
||||
std::begin(initializers),
|
||||
std::end(initializers),
|
||||
[&](auto& initializer) {
|
||||
return std::strcmp(initializer, input.name().c_str()) == 0;
|
||||
});
|
||||
|
||||
auto is_initializer = found_it != std::end(initializers);
|
||||
if (!is_initializer) {
|
||||
inputs_without_initializers.push_back(&input);
|
||||
}
|
||||
}
|
||||
}
|
||||
return inputs_without_initializers;
|
||||
}
|
||||
|
||||
static std::vector<const onnx::ValueInfoProto*>
|
||||
GetOutputs(const onnx::ModelProto& model_proto) {
|
||||
std::vector<const onnx::ValueInfoProto*> outputs_with_name;
|
||||
auto& graph = model_proto.graph();
|
||||
auto& outputs = graph.output();
|
||||
for (auto& output : outputs) {
|
||||
if (output.has_name() && output.has_type()) {
|
||||
outputs_with_name.push_back(&output);
|
||||
}
|
||||
}
|
||||
return outputs_with_name;
|
||||
}
|
||||
|
||||
ModelInfo::ModelInfo(
|
||||
const onnx::ModelProto* model_proto) {
|
||||
Initialize(model_proto);
|
||||
}
|
||||
|
||||
void ModelInfo::Initialize(
|
||||
const onnx::ModelProto* model_proto) {
|
||||
// metadata
|
||||
for (auto& prop : model_proto->metadata_props()) {
|
||||
model_metadata_[prop.key()] = prop.value();
|
||||
}
|
||||
|
||||
WinML::FeatureDescriptorFactory builder(model_metadata_);
|
||||
|
||||
// Create inputs
|
||||
auto inputs = GetInputsWithoutInitializers(*model_proto);
|
||||
input_features_ = builder.CreateDescriptorsFromValueInfoProtos(inputs);
|
||||
|
||||
// Create outputs
|
||||
auto outputs = ::GetOutputs(*model_proto);
|
||||
output_features_ = builder.CreateDescriptorsFromValueInfoProtos(outputs);
|
||||
|
||||
// author
|
||||
auto has_producer_name = model_proto->has_producer_name();
|
||||
author_ = has_producer_name
|
||||
? model_proto->producer_name()
|
||||
: "";
|
||||
|
||||
// domain
|
||||
auto has_domain = model_proto->has_domain();
|
||||
domain_ = has_domain
|
||||
? model_proto->domain()
|
||||
: "";
|
||||
|
||||
// name
|
||||
auto has_graph = model_proto->has_graph();
|
||||
auto graph_has_name = model_proto->graph().has_name();
|
||||
auto is_name_available = has_graph && graph_has_name;
|
||||
name_ = is_name_available
|
||||
? model_proto->graph().name()
|
||||
: "";
|
||||
|
||||
// description
|
||||
auto has_description = model_proto->has_doc_string();
|
||||
description_ = has_description
|
||||
? model_proto->doc_string()
|
||||
: "";
|
||||
|
||||
// version
|
||||
auto has_version = model_proto->has_model_version();
|
||||
version_ = has_version
|
||||
? model_proto->model_version()
|
||||
: 0;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@
|
|||
#include "pch.h"
|
||||
#include "inc/WinMLAdapter.h"
|
||||
#include "inc/CustomRegistryHelper.h"
|
||||
#include "PheonixSingleton.h"
|
||||
#include "inc/LotusEnvironment.h"
|
||||
#include "inc/AbiCustomRegistryImpl.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h"
|
||||
#include "core/providers/dml/GraphTransformers/GraphTransformerHelpers.h"
|
||||
#include "core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h"
|
||||
|
||||
#include "LearningModelDevice.h"
|
||||
#include "TensorFeatureDescriptor.h"
|
||||
|
|
@ -25,6 +27,8 @@
|
|||
#include "ZeroCopyInputStreamWrapper.h"
|
||||
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
||||
|
||||
#include "FeatureDescriptorFactory.h"
|
||||
|
||||
|
||||
using namespace winrt::Windows::AI::MachineLearning;
|
||||
|
||||
|
|
@ -111,7 +115,7 @@ public:
|
|||
*tensor = tensor_outer.Detach();
|
||||
return S_OK;
|
||||
}
|
||||
};
|
||||
}; // class AbiSafeOrtValue
|
||||
|
||||
class ModelProto : public Microsoft::WRL::RuntimeClass<
|
||||
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
|
||||
|
|
@ -128,12 +132,178 @@ public:
|
|||
|
||||
private:
|
||||
std::shared_ptr<onnx::ModelProto> model_proto_;
|
||||
};
|
||||
}; // class ModelProto
|
||||
|
||||
|
||||
class ModelInfo : public Microsoft::WRL::RuntimeClass<
|
||||
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
|
||||
IModelInfo> {
|
||||
|
||||
private:
|
||||
std::string author_;
|
||||
std::string name_;
|
||||
std::string domain_;
|
||||
std::string description_;
|
||||
int64_t version_;
|
||||
std::unordered_map<std::string, std::string> model_metadata_;
|
||||
wfc::IVector<winml::ILearningModelFeatureDescriptor> input_features_;
|
||||
wfc::IVector<winml::ILearningModelFeatureDescriptor> output_features_;
|
||||
|
||||
public:
|
||||
|
||||
ModelInfo(const onnx::ModelProto* model_proto) {
|
||||
Initialize(model_proto);
|
||||
}
|
||||
|
||||
std::string STDMETHODCALLTYPE author() override {
|
||||
return author_;
|
||||
}
|
||||
std::string STDMETHODCALLTYPE name() override {
|
||||
return name_;
|
||||
}
|
||||
std::string STDMETHODCALLTYPE domain() override {
|
||||
return domain_;
|
||||
}
|
||||
std::string STDMETHODCALLTYPE description() override {
|
||||
return description_;
|
||||
}
|
||||
int64_t STDMETHODCALLTYPE version() override {
|
||||
return version_;
|
||||
}
|
||||
std::unordered_map<std::string, std::string> STDMETHODCALLTYPE model_metadata() override {
|
||||
return model_metadata_;
|
||||
}
|
||||
wfc::IVector<winml::ILearningModelFeatureDescriptor> STDMETHODCALLTYPE input_features() override {
|
||||
return input_features_;
|
||||
}
|
||||
wfc::IVector<winml::ILearningModelFeatureDescriptor> STDMETHODCALLTYPE output_features() override {
|
||||
return output_features_;
|
||||
}
|
||||
|
||||
static std::vector<const char*>
|
||||
GetAllNodeOutputs(const onnx::ModelProto& model_proto) {
|
||||
std::vector<const char*> nodes_outputs;
|
||||
auto& graph = model_proto.graph();
|
||||
auto& nodes = graph.node();
|
||||
for (auto& node : nodes) {
|
||||
for (auto& node_output : node.output()) {
|
||||
nodes_outputs.push_back(node_output.c_str());
|
||||
}
|
||||
}
|
||||
return nodes_outputs;
|
||||
}
|
||||
|
||||
static std::vector<const char*>
|
||||
GetInitializers(const onnx::ModelProto& model_proto) {
|
||||
std::vector<const char*> initializers;
|
||||
auto& graph = model_proto.graph();
|
||||
auto& graph_initializers = graph.initializer();
|
||||
for (auto& initializer : graph_initializers) {
|
||||
initializers.push_back(initializer.name().c_str());
|
||||
}
|
||||
return initializers;
|
||||
}
|
||||
|
||||
static std::vector<const onnx::ValueInfoProto*>
|
||||
GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) {
|
||||
auto initializers = GetInitializers(model_proto);
|
||||
|
||||
std::vector<const onnx::ValueInfoProto*> inputs_without_initializers;
|
||||
auto& graph = model_proto.graph();
|
||||
auto& inputs = graph.input();
|
||||
for (auto& input : inputs) {
|
||||
if (input.has_name() && input.has_type()) {
|
||||
auto found_it = std::find_if(
|
||||
std::begin(initializers),
|
||||
std::end(initializers),
|
||||
[&](auto& initializer) {
|
||||
return std::strcmp(initializer, input.name().c_str()) == 0;
|
||||
});
|
||||
|
||||
auto is_initializer = found_it != std::end(initializers);
|
||||
if (!is_initializer) {
|
||||
inputs_without_initializers.push_back(&input);
|
||||
}
|
||||
}
|
||||
}
|
||||
return inputs_without_initializers;
|
||||
}
|
||||
|
||||
static
|
||||
std::vector<const onnx::ValueInfoProto*> GetOutputs(const onnx::ModelProto& model_proto) {
|
||||
std::vector<const onnx::ValueInfoProto*> outputs_with_name;
|
||||
auto& graph = model_proto.graph();
|
||||
auto& outputs = graph.output();
|
||||
for (auto& output : outputs) {
|
||||
if (output.has_name() && output.has_type()) {
|
||||
outputs_with_name.push_back(&output);
|
||||
}
|
||||
}
|
||||
return outputs_with_name;
|
||||
}
|
||||
|
||||
private:
|
||||
void Initialize(const onnx::ModelProto* model_proto) {
|
||||
// metadata
|
||||
for (auto& prop : model_proto->metadata_props()) {
|
||||
model_metadata_[prop.key()] = prop.value();
|
||||
}
|
||||
|
||||
WinML::FeatureDescriptorFactory builder(model_metadata_);
|
||||
|
||||
// Create inputs
|
||||
auto inputs = GetInputsWithoutInitializers(*model_proto);
|
||||
input_features_ = builder.CreateDescriptorsFromValueInfoProtos(inputs);
|
||||
|
||||
// Create outputs
|
||||
auto outputs = GetOutputs(*model_proto);
|
||||
output_features_ = builder.CreateDescriptorsFromValueInfoProtos(outputs);
|
||||
|
||||
// author
|
||||
auto has_producer_name = model_proto->has_producer_name();
|
||||
author_ = has_producer_name
|
||||
? model_proto->producer_name()
|
||||
: "";
|
||||
|
||||
// domain
|
||||
auto has_domain = model_proto->has_domain();
|
||||
domain_ = has_domain
|
||||
? model_proto->domain()
|
||||
: "";
|
||||
|
||||
// name
|
||||
auto has_graph = model_proto->has_graph();
|
||||
auto graph_has_name = model_proto->graph().has_name();
|
||||
auto is_name_available = has_graph && graph_has_name;
|
||||
name_ = is_name_available
|
||||
? model_proto->graph().name()
|
||||
: "";
|
||||
|
||||
// description
|
||||
auto has_description = model_proto->has_doc_string();
|
||||
description_ = has_description
|
||||
? model_proto->doc_string()
|
||||
: "";
|
||||
|
||||
// version
|
||||
auto has_version = model_proto->has_model_version();
|
||||
version_ = has_version
|
||||
? model_proto->model_version()
|
||||
: 0;
|
||||
}
|
||||
}; // class ModelInfo
|
||||
|
||||
class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
|
||||
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
|
||||
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
|
||||
IWinMLAdapter> {
|
||||
private:
|
||||
std::shared_ptr<WinML::LotusEnvironment> lotus_environment_;
|
||||
|
||||
public:
|
||||
WinMLAdapter() : lotus_environment_(PheonixSingleton<WinML::LotusEnvironment>()) {
|
||||
|
||||
}
|
||||
|
||||
// factory methods for creating an ort model from a path
|
||||
HRESULT STDMETHODCALLTYPE CreateModelProto(
|
||||
const char* path,
|
||||
|
|
@ -188,6 +358,12 @@ public:
|
|||
return model_proto_outer.CopyTo(__uuidof(IModelProto), (void**)model_proto);
|
||||
}
|
||||
|
||||
HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto * model_proto, IModelInfo ** model_info) override {
|
||||
auto model_info_outer = wil::MakeOrThrow<ModelInfo>(model_proto->get());
|
||||
return model_info_outer.CopyTo(__uuidof(IModelInfo), (void**)model_info);
|
||||
}
|
||||
|
||||
|
||||
void STDMETHODCALLTYPE EnableDebugOutput() override {
|
||||
WinML::CWinMLLogSink::EnableDebugOutput();
|
||||
}
|
||||
|
|
@ -516,6 +692,19 @@ public:
|
|||
return S_OK;
|
||||
}
|
||||
|
||||
// Override select shape inference functions which are incomplete in ONNX with versions that are complete,
|
||||
// and are also used in DML kernel registrations. Doing this avoids kernel and shader creation being
|
||||
// deferred until first evaluation. It also prevents a situation where inference functions in externally
|
||||
// registered schema are reachable only after upstream schema have been revised in a later OS release,
|
||||
// which would be a compatibility risk.
|
||||
HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() override {
|
||||
static std::once_flag schema_override_once_flag;
|
||||
std::call_once(schema_override_once_flag, []() {
|
||||
SchemaInferenceOverrider::OverrideSchemaInferenceFunctions();
|
||||
});
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -3,27 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "WinMLAdapter.h"
|
||||
|
||||
namespace Windows::AI::MachineLearning {
|
||||
|
||||
class ModelInfo {
|
||||
public:
|
||||
ModelInfo(const onnx::ModelProto* model_proto);
|
||||
|
||||
public:
|
||||
// model metadata
|
||||
std::string author_;
|
||||
std::string name_;
|
||||
std::string domain_;
|
||||
std::string description_;
|
||||
int64_t version_;
|
||||
std::unordered_map<std::string, std::string> model_metadata_;
|
||||
wfc::IVector<winml::ILearningModelFeatureDescriptor> input_features_;
|
||||
wfc::IVector<winml::ILearningModelFeatureDescriptor> output_features_;
|
||||
|
||||
private:
|
||||
void Initialize(const onnx::ModelProto* model_proto);
|
||||
};
|
||||
|
||||
} // namespace Windows::AI::MachineLearning
|
||||
|
|
@ -4,9 +4,22 @@
|
|||
#pragma once
|
||||
|
||||
#include "IOrtSessionBuilder.h"
|
||||
#include "ModelInfo.h"
|
||||
|
||||
namespace Windows::AI::MachineLearning::Adapter {
|
||||
|
||||
MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") IModelInfo : IUnknown{
|
||||
// model metadata
|
||||
virtual std::string STDMETHODCALLTYPE author() = 0;
|
||||
virtual std::string STDMETHODCALLTYPE name() = 0;
|
||||
virtual std::string STDMETHODCALLTYPE domain() = 0;
|
||||
virtual std::string STDMETHODCALLTYPE description() = 0;
|
||||
virtual int64_t STDMETHODCALLTYPE version() = 0;
|
||||
virtual std::unordered_map<std::string, std::string> STDMETHODCALLTYPE model_metadata() = 0;
|
||||
virtual wfc::IVector<winml::ILearningModelFeatureDescriptor> STDMETHODCALLTYPE input_features() = 0;
|
||||
virtual wfc::IVector<winml::ILearningModelFeatureDescriptor> STDMETHODCALLTYPE output_features() = 0;
|
||||
};
|
||||
|
||||
MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") ITensor : IUnknown{
|
||||
// these all return weak pointers
|
||||
virtual const onnxruntime::Tensor& STDMETHODCALLTYPE get() = 0;
|
||||
|
|
@ -92,14 +105,11 @@ MIDL_INTERFACE("b19385e7-d9af-441a-ba7f-3993c7b1c9db") IWinMLAdapter : IUnknown
|
|||
ID3D12CommandQueue* queue,
|
||||
IOrtSessionBuilder** session_builder) = 0;
|
||||
|
||||
// factory methods for creating an ort model from a path
|
||||
// factory methods for creating model protos
|
||||
virtual HRESULT STDMETHODCALLTYPE CreateModelProto(const char* path, IModelProto** model_proto) = 0;
|
||||
|
||||
// factory methods for creating an ort model from a stream
|
||||
virtual HRESULT STDMETHODCALLTYPE CreateModelProto(ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream_reference, IModelProto** model_proto) = 0;
|
||||
|
||||
// factory methods for creating an ort model from a model_proto
|
||||
virtual HRESULT STDMETHODCALLTYPE CreateModelProto(IModelProto * model_proto_in, IModelProto** model_proto) = 0;
|
||||
virtual HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto * model_proto, IModelInfo ** model_info) = 0;
|
||||
|
||||
// Data types
|
||||
virtual onnxruntime::MLDataType STDMETHODCALLTYPE GetTensorType() = 0;
|
||||
|
|
@ -142,7 +152,8 @@ MIDL_INTERFACE("b19385e7-d9af-441a-ba7f-3993c7b1c9db") IWinMLAdapter : IUnknown
|
|||
onnxruntime::MLDataType data_type,
|
||||
IOrtValue ** ort_value) = 0;
|
||||
|
||||
|
||||
// schema overrides (dml does this for us)
|
||||
virtual HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() = 0;
|
||||
};
|
||||
|
||||
extern "C"
|
||||
|
|
|
|||
|
|
@ -5,10 +5,8 @@
|
|||
|
||||
#include "LearningModel.h"
|
||||
|
||||
#include "core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h"
|
||||
#include "ModelInfo.h"
|
||||
#include "PheonixSingleton.h"
|
||||
#include "TelemetryEvent.h"
|
||||
|
||||
#include "LotusEnvironment.h"
|
||||
|
|
@ -21,21 +19,19 @@ namespace winrt::Windows::AI::MachineLearning::implementation {
|
|||
LearningModel::LearningModel(
|
||||
const hstring& path,
|
||||
const winml::ILearningModelOperatorProvider op_provider) try : LearningModel(WinML::Strings::UTF8FromHString(path),
|
||||
op_provider) {}
|
||||
op_provider) {
|
||||
}
|
||||
WINML_CATCH_ALL
|
||||
|
||||
LearningModel::LearningModel(
|
||||
const std::string& path,
|
||||
const winml::ILearningModelOperatorProvider operator_provider) try : lotus_environment_(PheonixSingleton<WinML::LotusEnvironment>()),
|
||||
operator_provider_(operator_provider) {
|
||||
const winml::ILearningModelOperatorProvider operator_provider) try : operator_provider_(operator_provider) {
|
||||
_winmlt::PerformanceTelemetryEvent kLoadModel_event(
|
||||
WinMLRuntimePerf::kLoadModel);
|
||||
|
||||
OverrideShapeInferenceMethods();
|
||||
|
||||
com_ptr<_winmla::IWinMLAdapter> adapter;
|
||||
WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put()));
|
||||
WINML_THROW_IF_FAILED(adapter->CreateModelProto(path.c_str(), model_proto_.put()));
|
||||
WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put()));
|
||||
WINML_THROW_IF_FAILED(adapter_->OverrideSchemaInferenceFunctions());
|
||||
WINML_THROW_IF_FAILED(adapter_->CreateModelProto(path.c_str(), model_proto_.put()));
|
||||
|
||||
Initialize();
|
||||
|
||||
|
|
@ -45,19 +41,16 @@ WINML_CATCH_ALL
|
|||
|
||||
LearningModel::LearningModel(
|
||||
const wss::IRandomAccessStreamReference stream,
|
||||
const winml::ILearningModelOperatorProvider operator_provider) try : lotus_environment_(PheonixSingleton<WinML::LotusEnvironment>()),
|
||||
operator_provider_(operator_provider) {
|
||||
const winml::ILearningModelOperatorProvider operator_provider) try : operator_provider_(operator_provider) {
|
||||
_winmlt::PerformanceTelemetryEvent kLoadModel_event(
|
||||
WinMLRuntimePerf::kLoadModel);
|
||||
|
||||
OverrideShapeInferenceMethods();
|
||||
|
||||
com_ptr<_winmla::IWinMLAdapter> adapter;
|
||||
WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put()));
|
||||
WINML_THROW_IF_FAILED(adapter->CreateModelProto(
|
||||
static_cast<ABI::Windows::Storage::Streams::IRandomAccessStreamReference*>(winrt::get_abi(stream)),
|
||||
WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put()));
|
||||
WINML_THROW_IF_FAILED(adapter_->OverrideSchemaInferenceFunctions());
|
||||
WINML_THROW_IF_FAILED(adapter_->CreateModelProto(
|
||||
static_cast<ABI::Windows::Storage::Streams::IRandomAccessStreamReference*>(winrt::get_abi(stream)),
|
||||
model_proto_.put()));
|
||||
|
||||
|
||||
Initialize();
|
||||
|
||||
LogCreationEvent(true);
|
||||
|
|
@ -65,8 +58,7 @@ LearningModel::LearningModel(
|
|||
WINML_CATCH_ALL
|
||||
|
||||
void LearningModel::Initialize() {
|
||||
model_info_ = std::make_unique<WinML::ModelInfo>(
|
||||
model_proto_.get()->get());
|
||||
WINML_THROW_IF_FAILED(adapter_->CreateModelInfo(model_proto_.get(), model_info_.put()));
|
||||
}
|
||||
|
||||
void LearningModel::LogCreationEvent(bool fromStream) {
|
||||
|
|
@ -80,13 +72,13 @@ void LearningModel::LogCreationEvent(bool fromStream) {
|
|||
}
|
||||
telemetry_helper.LogModelCreation(
|
||||
fromStream,
|
||||
model_info_->author_,
|
||||
model_info_->name_,
|
||||
model_info_->domain_,
|
||||
model_info_->description_,
|
||||
model_info_->version_,
|
||||
model_info_->author(),
|
||||
model_info_->name(),
|
||||
model_info_->domain(),
|
||||
model_info_->description(),
|
||||
model_info_->version(),
|
||||
use_fp16,
|
||||
model_info_->model_metadata_);
|
||||
model_info_->model_metadata());
|
||||
}
|
||||
|
||||
void LearningModel::ModelUseFP16(
|
||||
|
|
@ -119,41 +111,41 @@ void LearningModel::ModelUseFP16(
|
|||
|
||||
hstring
|
||||
LearningModel::Author() try {
|
||||
return WinML::Strings::HStringFromUTF8(model_info_->author_);
|
||||
return WinML::Strings::HStringFromUTF8(model_info_->author());
|
||||
}
|
||||
WINML_CATCH_ALL
|
||||
|
||||
hstring
|
||||
LearningModel::Name() try {
|
||||
return WinML::Strings::HStringFromUTF8(
|
||||
model_info_->name_);
|
||||
model_info_->name());
|
||||
}
|
||||
WINML_CATCH_ALL
|
||||
|
||||
hstring
|
||||
LearningModel::Domain() try {
|
||||
return WinML::Strings::HStringFromUTF8(
|
||||
model_info_->domain_);
|
||||
model_info_->domain());
|
||||
}
|
||||
WINML_CATCH_ALL
|
||||
|
||||
hstring
|
||||
LearningModel::Description() try {
|
||||
return WinML::Strings::HStringFromUTF8(
|
||||
model_info_->description_);
|
||||
model_info_->description());
|
||||
}
|
||||
WINML_CATCH_ALL
|
||||
|
||||
int64_t
|
||||
LearningModel::Version() try {
|
||||
return model_info_->version_;
|
||||
return model_info_->version();
|
||||
}
|
||||
WINML_CATCH_ALL
|
||||
|
||||
wfc::IMapView<hstring, hstring>
|
||||
LearningModel::Metadata() try {
|
||||
std::unordered_map<hstring, hstring> map_copy;
|
||||
for (auto& pair : model_info_->model_metadata_) {
|
||||
for (auto& pair : model_info_->model_metadata()) {
|
||||
auto key = WinML::Strings::HStringFromUTF8(pair.first);
|
||||
auto value = WinML::Strings::HStringFromUTF8(pair.second);
|
||||
map_copy.emplace(std::move(key), std::move(value));
|
||||
|
|
@ -183,13 +175,13 @@ LearningModel::GetOperatorRegistry() {
|
|||
|
||||
wfc::IVectorView<winml::ILearningModelFeatureDescriptor>
|
||||
LearningModel::InputFeatures() try {
|
||||
return model_info_->input_features_.GetView();
|
||||
return model_info_->input_features().GetView();
|
||||
}
|
||||
WINML_CATCH_ALL
|
||||
|
||||
wfc::IVectorView<winml::ILearningModelFeatureDescriptor>
|
||||
LearningModel::OutputFeatures() try {
|
||||
return model_info_->output_features_.GetView();
|
||||
return model_info_->output_features().GetView();
|
||||
}
|
||||
WINML_CATCH_ALL
|
||||
|
||||
|
|
@ -287,18 +279,6 @@ LearningModel::CopyModelProto() {
|
|||
return model_proto.detach();
|
||||
}
|
||||
|
||||
static std::once_flag g_schema_override_once_flag;
|
||||
|
||||
// Override select shape inference functions which are incomplete in ONNX with versions that are complete,
|
||||
// and are also used in DML kernel registrations. Doing this avoids kernel and shader creation being
|
||||
// deferred until first evaluation. It also prevents a situation where inference functions in externally
|
||||
// registered schema are reachable only after upstream schema have been revised in a later OS release,
|
||||
// which would be a compatibility risk.
|
||||
void LearningModel::OverrideShapeInferenceMethods() {
|
||||
std::call_once(g_schema_override_once_flag, []() {
|
||||
SchemaInferenceOverrider::OverrideSchemaInferenceFunctions();
|
||||
});
|
||||
}
|
||||
} // namespace winrt::Windows::AI::MachineLearning::implementation
|
||||
|
||||
namespace winrt::Windows::AI::MachineLearning::factory_implementation {
|
||||
|
|
|
|||
|
|
@ -6,11 +6,6 @@
|
|||
#include "LearningModel.g.h"
|
||||
#include "WinMLAdapter.h"
|
||||
|
||||
namespace Windows::AI::MachineLearning {
|
||||
class LotusEnvironment;
|
||||
class ModelInfo;
|
||||
} // namespace Windows::AI::MachineLearning
|
||||
|
||||
namespace winrt::Windows::AI::MachineLearning::implementation {
|
||||
|
||||
struct LearningModel : LearningModelT<LearningModel> {
|
||||
|
|
@ -121,13 +116,10 @@ struct LearningModel : LearningModelT<LearningModel> {
|
|||
winml::ILearningModelFeatureDescriptor descriptor,
|
||||
bool& use_fp16);
|
||||
|
||||
void
|
||||
OverrideShapeInferenceMethods();
|
||||
|
||||
private:
|
||||
std::shared_ptr<WinML::LotusEnvironment> lotus_environment_;
|
||||
com_ptr<_winmla::IWinMLAdapter> adapter_;
|
||||
com_ptr<_winmla::IModelProto> model_proto_;
|
||||
std::unique_ptr<WinML::ModelInfo> model_info_;
|
||||
com_ptr<_winmla::IModelInfo> model_info_;
|
||||
ILearningModelOperatorProvider operator_provider_;
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue