2019-11-08 00:50:24 +00:00
|
|
|
// Copyright (c) Microsoft Corporation.
|
|
|
|
|
// Licensed under the MIT License.
|
|
|
|
|
|
|
|
|
|
#include "pch.h"
|
|
|
|
|
#include "inc/WinMLAdapter.h"
|
|
|
|
|
#include "inc/CustomRegistryHelper.h"
|
2019-11-15 18:54:44 +00:00
|
|
|
#include "PheonixSingleton.h"
|
2019-11-08 00:50:24 +00:00
|
|
|
#include "inc/LotusEnvironment.h"
|
2019-11-15 01:44:07 +00:00
|
|
|
#include "inc/AbiCustomRegistryImpl.h"
|
2019-11-23 00:13:28 +00:00
|
|
|
|
|
|
|
|
#ifdef USE_DML
|
2019-11-08 00:50:24 +00:00
|
|
|
#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h"
|
2019-11-15 01:44:07 +00:00
|
|
|
#include "core/providers/dml/GraphTransformers/GraphTransformerHelpers.h"
|
2019-11-15 18:54:44 +00:00
|
|
|
#include "core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h"
|
2019-11-23 00:13:28 +00:00
|
|
|
#include "DmlOrtSessionBuilder.h"
|
|
|
|
|
#endif USE_DML
|
2019-11-08 00:50:24 +00:00
|
|
|
|
|
|
|
|
#include "LearningModelDevice.h"
|
|
|
|
|
#include "TensorFeatureDescriptor.h"
|
|
|
|
|
#include "ImageFeatureDescriptor.h"
|
|
|
|
|
#include "api.image/inc/D3DDeviceCache.h"
|
2019-11-22 15:23:20 +00:00
|
|
|
#include "Common/inc/WinMLTelemetryHelper.h"
|
2019-11-08 00:50:24 +00:00
|
|
|
|
2019-11-15 01:44:07 +00:00
|
|
|
#include "CpuOrtSessionBuilder.h"
|
2019-11-08 00:50:24 +00:00
|
|
|
|
2019-11-15 01:44:07 +00:00
|
|
|
#include <io.h>
#include <fcntl.h>
#include <new>
#include <string_view>
#include <unordered_set>
|
2019-11-08 00:50:24 +00:00
|
|
|
|
2019-11-15 01:44:07 +00:00
|
|
|
#include "ZeroCopyInputStreamWrapper.h"
|
|
|
|
|
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
2019-11-08 00:50:24 +00:00
|
|
|
|
2019-11-15 18:54:44 +00:00
|
|
|
#include "FeatureDescriptorFactory.h"
|
|
|
|
|
|
2019-11-15 01:44:07 +00:00
|
|
|
using namespace winrt::Windows::AI::MachineLearning;
|
2019-11-08 00:50:24 +00:00
|
|
|
|
2019-11-15 01:44:07 +00:00
|
|
|
namespace Windows::AI::MachineLearning::Adapter {
|
2019-11-08 00:50:24 +00:00
|
|
|
|
2019-11-22 15:23:20 +00:00
|
|
|
// Define winml trace logging provider with WinML GUID
|
|
|
|
|
TRACELOGGING_DEFINE_PROVIDER(
|
|
|
|
|
winml_trace_logging_provider,
|
|
|
|
|
WINML_PROVIDER_DESC,
|
|
|
|
|
WINML_PROVIDER_GUID);
|
|
|
|
|
|
2019-11-08 00:50:24 +00:00
|
|
|
// ORT intentionally requires callers derive from their session class to access
|
2019-11-25 20:50:04 +00:00
|
|
|
// the protected methods used below.
|
2019-11-08 00:50:24 +00:00
|
|
|
class InferenceSessionProtectedLoadAccessor : public onnxruntime::InferenceSession {
|
|
|
|
|
public:
|
|
|
|
|
onnxruntime::common::Status
|
|
|
|
|
Load(std::unique_ptr<ONNX_NAMESPACE::ModelProto> p_model_proto) {
|
|
|
|
|
return onnxruntime::InferenceSession::Load(std::move(p_model_proto));
|
|
|
|
|
}
|
2019-11-25 20:50:04 +00:00
|
|
|
const onnxruntime::SessionState& GetSessionState() {
|
|
|
|
|
return session_state_;
|
2019-11-19 02:31:04 +00:00
|
|
|
}
|
2019-11-15 01:44:07 +00:00
|
|
|
};
|
|
|
|
|
|
2019-11-19 02:31:04 +00:00
|
|
|
// COM wrapper that owns an onnx::ModelProto and exposes it through IModelProto.
class ModelProto : public Microsoft::WRL::RuntimeClass<
                       Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
                       IModelProto> {
 public:
  // Takes ownership of `model_proto`; it is deleted together with this object
  // unless detach() is called first.
  explicit ModelProto(onnx::ModelProto* model_proto) : model_proto_(model_proto) {
  }

  // Returns the owned proto without transferring ownership (nullptr after detach()).
  onnx::ModelProto* STDMETHODCALLTYPE get() override {
    return model_proto_.get();
  }

  // Relinquishes ownership: the caller becomes responsible for deleting the result.
  onnx::ModelProto* STDMETHODCALLTYPE detach() override {
    return model_proto_.release();
  }

 private:
  std::unique_ptr<onnx::ModelProto> model_proto_;
};  // class ModelProto
|
2019-11-15 18:54:44 +00:00
|
|
|
|
|
|
|
|
// COM implementation of IModelInfo: descriptive metadata (author, name, domain,
// description, version, metadata_props) plus input/output feature descriptors,
// all computed once from an onnx::ModelProto at construction.
class ModelInfo : public Microsoft::WRL::RuntimeClass<
                      Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
                      IModelInfo> {
 private:
  std::string author_;
  std::string name_;
  std::string domain_;
  std::string description_;
  int64_t version_ = 0;
  std::unordered_map<std::string, std::string> model_metadata_;
  wfc::IVector<winml::ILearningModelFeatureDescriptor> input_features_;
  wfc::IVector<winml::ILearningModelFeatureDescriptor> output_features_;

 public:
  // Reads everything from `model_proto` immediately; the pointer is dereferenced
  // unconditionally, so it must be non-null.
  explicit ModelInfo(const onnx::ModelProto* model_proto) {
    Initialize(model_proto);
  }

  // Model producer_name ("" if unset).
  const char* STDMETHODCALLTYPE author() override {
    return author_.c_str();
  }

  // Graph name ("" if the model has no named graph).
  const char* STDMETHODCALLTYPE name() override {
    return name_.c_str();
  }

  // Model domain ("" if unset).
  const char* STDMETHODCALLTYPE domain() override {
    return domain_.c_str();
  }

  // Model doc_string ("" if unset).
  const char* STDMETHODCALLTYPE description() override {
    return description_.c_str();
  }

  // Model model_version (0 if unset).
  int64_t STDMETHODCALLTYPE version() override {
    return version_;
  }

  // Returns a new read-only map view of metadata_props, converted from UTF-8 to HSTRINGs.
  HRESULT STDMETHODCALLTYPE GetModelMetadata(
      ABI::Windows::Foundation::Collections::IMapView<HSTRING, HSTRING>** metadata) override {
    *metadata = nullptr;
    std::unordered_map<winrt::hstring, winrt::hstring> map_copy;
    for (const auto& pair : model_metadata_) {
      auto key = WinML::Strings::HStringFromUTF8(pair.first);
      auto map_value = WinML::Strings::HStringFromUTF8(pair.second);
      map_copy.emplace(std::move(key), std::move(map_value));
    }
    auto out = winrt::single_threaded_map<winrt::hstring, winrt::hstring>(
        std::move(map_copy));

    winrt::copy_to_abi(out.GetView(), *reinterpret_cast<void**>(metadata));
    return S_OK;
  }

  // Returns a view over the descriptors for graph inputs that are not initializers.
  HRESULT STDMETHODCALLTYPE GetInputFeatures(
      ABI::Windows::Foundation::Collections::IVectorView<winml::ILearningModelFeatureDescriptor>** features) override {
    *features = nullptr;
    winrt::copy_to_abi(input_features_.GetView(), *reinterpret_cast<void**>(features));
    return S_OK;
  }

  // Returns a view over the descriptors for named, typed graph outputs.
  HRESULT STDMETHODCALLTYPE GetOutputFeatures(
      ABI::Windows::Foundation::Collections::IVectorView<winml::ILearningModelFeatureDescriptor>** features) override {
    *features = nullptr;
    winrt::copy_to_abi(output_features_.GetView(), *reinterpret_cast<void**>(features));
    return S_OK;
  }

  // Collects the output names of every node in the graph. The returned pointers
  // point into `model_proto` and are only valid while it is alive and unmodified.
  static std::vector<const char*>
  GetAllNodeOutputs(const onnx::ModelProto& model_proto) {
    std::vector<const char*> nodes_outputs;
    for (const auto& node : model_proto.graph().node()) {
      for (const auto& node_output : node.output()) {
        nodes_outputs.push_back(node_output.c_str());
      }
    }
    return nodes_outputs;
  }

  // Collects the names of all graph initializers. The returned pointers point into
  // `model_proto` and are only valid while it is alive and unmodified.
  static std::vector<const char*>
  GetInitializers(const onnx::ModelProto& model_proto) {
    std::vector<const char*> initializers;
    for (const auto& initializer : model_proto.graph().initializer()) {
      initializers.push_back(initializer.name().c_str());
    }
    return initializers;
  }

  // Returns the named, typed graph inputs whose name does not match any initializer.
  static std::vector<const onnx::ValueInfoProto*>
  GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) {
    auto initializers = GetInitializers(model_proto);

    // Hash the initializer names once so each input is classified in O(1) on average,
    // instead of a linear scan over every initializer per input (O(inputs * initializers)).
    // The string_views refer to strings owned by model_proto, which outlives this call.
    const std::unordered_set<std::string_view> initializer_names(
        initializers.begin(), initializers.end());

    std::vector<const onnx::ValueInfoProto*> inputs_without_initializers;
    for (const auto& input : model_proto.graph().input()) {
      if (input.has_name() && input.has_type()) {
        const bool is_initializer = initializer_names.count(input.name()) != 0;
        if (!is_initializer) {
          inputs_without_initializers.push_back(&input);
        }
      }
    }
    return inputs_without_initializers;
  }

  // Returns the graph outputs that have both a name and a type.
  static std::vector<const onnx::ValueInfoProto*> GetOutputs(const onnx::ModelProto& model_proto) {
    std::vector<const onnx::ValueInfoProto*> outputs_with_name;
    for (const auto& output : model_proto.graph().output()) {
      if (output.has_name() && output.has_type()) {
        outputs_with_name.push_back(&output);
      }
    }
    return outputs_with_name;
  }

 private:
  // Populates every member from `model_proto`.
  void Initialize(const onnx::ModelProto* model_proto) {
    // metadata (later duplicate keys overwrite earlier ones)
    for (const auto& prop : model_proto->metadata_props()) {
      model_metadata_[prop.key()] = prop.value();
    }

    // The descriptor factory is given the metadata map alongside the value infos.
    WinML::FeatureDescriptorFactory builder(model_metadata_);

    // Create inputs
    auto inputs = GetInputsWithoutInitializers(*model_proto);
    input_features_ = builder.CreateDescriptorsFromValueInfoProtos(inputs);

    // Create outputs
    auto outputs = GetOutputs(*model_proto);
    output_features_ = builder.CreateDescriptorsFromValueInfoProtos(outputs);

    // author
    author_ = model_proto->has_producer_name()
                  ? model_proto->producer_name()
                  : "";

    // domain
    domain_ = model_proto->has_domain()
                  ? model_proto->domain()
                  : "";

    // name
    const bool has_graph = model_proto->has_graph();
    const bool graph_has_name = model_proto->graph().has_name();
    name_ = (has_graph && graph_has_name)
                ? model_proto->graph().name()
                : "";

    // description
    description_ = model_proto->has_doc_string()
                       ? model_proto->doc_string()
                       : "";

    // version
    version_ = model_proto->has_model_version()
                   ? model_proto->model_version()
                   : 0;
  }
};  // class ModelInfo
|
2019-11-15 01:44:07 +00:00
|
|
|
|
2019-11-19 02:31:04 +00:00
|
|
|
// COM implementation of IWinMLAdapter: the bridge WinML uses to reach onnxruntime
// internals (model loading and inspection, session builder creation, map type queries,
// memory/allocator helpers and, when built with USE_DML, DirectML resource helpers).
class WinMLAdapter : public Microsoft::WRL::RuntimeClass<
                         Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
                         IWinMLAdapter> {
 private:
  // Obtained from PheonixSingleton at construction and held for the adapter's lifetime
  // (presumably to keep the shared LotusEnvironment alive while adapters exist — confirm).
  std::shared_ptr<WinML::LotusEnvironment> lotus_environment_;

  // Wraps an owned onnx::ModelProto in a COM ModelProto and returns it through
  // `model_proto`. Ownership is handed to the wrapper only after the wrapper has been
  // created, so the proto is freed (not leaked) if MakeOrThrow throws.
  static HRESULT WrapModelProto(std::unique_ptr<onnx::ModelProto> inner, IModelProto** model_proto) {
    auto outer = wil::MakeOrThrow<ModelProto>(inner.get());
    static_cast<void>(inner.release());  // now owned by `outer`
    return outer.CopyTo(__uuidof(IModelProto), reinterpret_cast<void**>(model_proto));
  }

 public:
  WinMLAdapter() : lotus_environment_(PheonixSingleton<WinML::LotusEnvironment>()) {
  }

  // Factory method: parses an ONNX model from the file at `path`.
  // Throws HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) if the file does not exist, E_FAIL if it
  // cannot be opened for any other reason, and E_INVALIDARG if its contents fail to parse.
  HRESULT STDMETHODCALLTYPE CreateModelProto(
      const char* path,
      IModelProto** model_proto) override {
    int file_descriptor = -1;
    // _sopen_s returns the errno value describing the failure (0 on success).
    const errno_t err = _sopen_s(
        &file_descriptor,
        path,
        O_RDONLY | _O_SEQUENTIAL | _O_BINARY,
        _SH_DENYWR,
        _S_IREAD | _S_IWRITE);

    THROW_HR_IF_MSG(
        __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND),
        err == ENOENT,
        "File not found: %s",
        path);

    THROW_HR_IF_MSG(
        E_FAIL,
        err != 0 || file_descriptor < 0,
        "Failed to open file: %s (errno %d)",
        path,
        static_cast<int>(err));

    // The stream owns the descriptor from here on and closes it when destroyed,
    // including when a parse failure throws below.
    google::protobuf::io::FileInputStream stream(file_descriptor);
    stream.SetCloseOnDelete(true);

    auto model_proto_inner = std::make_unique<onnx::ModelProto>();
    THROW_HR_IF_MSG(
        E_INVALIDARG,
        model_proto_inner->ParseFromZeroCopyStream(&stream) == false,
        "The stream failed to parse.");

    return WrapModelProto(std::move(model_proto_inner), model_proto);
  }

  // Factory method: parses an ONNX model from a WinRT stream reference.
  // Throws E_INVALIDARG if the contents fail to parse.
  HRESULT STDMETHODCALLTYPE CreateModelProto(
      ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream_reference,
      IModelProto** model_proto) override {
    ZeroCopyInputStreamWrapper wrapper(stream_reference);

    auto model_proto_inner = std::make_unique<onnx::ModelProto>();
    THROW_HR_IF_MSG(
        E_INVALIDARG,
        model_proto_inner->ParseFromZeroCopyStream(&wrapper) == false,
        "The stream failed to parse.");

    return WrapModelProto(std::move(model_proto_inner), model_proto);
  }

  // Factory method: deep-copies the proto held by `model_proto_in` into a new IModelProto.
  HRESULT STDMETHODCALLTYPE CreateModelProto(IModelProto* model_proto_in, IModelProto** model_proto) override {
    auto model_proto_inner = std::make_unique<onnx::ModelProto>(*model_proto_in->get());
    return WrapModelProto(std::move(model_proto_inner), model_proto);
  }

  // Builds an IModelInfo describing the proto held by `model_proto`.
  HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto* model_proto, IModelInfo** model_info) override {
    auto model_info_outer = wil::MakeOrThrow<ModelInfo>(model_proto->get());
    return model_info_outer.CopyTo(__uuidof(IModelInfo), reinterpret_cast<void**>(model_info));
  }

  void STDMETHODCALLTYPE EnableDebugOutput() override {
    WinML::CWinMLLogSink::EnableDebugOutput();
  }

  // True when the descriptor is an image or tensor feature whose TensorKind is Float16.
  static bool IsFeatureDescriptorFp16(
      const winml::ILearningModelFeatureDescriptor& descriptor) {
    if (auto imageFeatureDescriptor = descriptor.try_as<winml::IImageFeatureDescriptor2>()) {
      return TensorKind::Float16 == imageFeatureDescriptor.TensorKind();
    }

    if (auto tensorFeatureDescriptor = descriptor.try_as<winml::ITensorFeatureDescriptor>()) {
      return TensorKind::Float16 == tensorFeatureDescriptor.TensorKind();
    }

    return false;
  }

  // When the device does not support fp16, throws DXGI_ERROR_UNSUPPORTED if the model uses
  // fp16 anywhere this method checks; otherwise (or when fp16 is supported) returns S_OK.
  HRESULT STDMETHODCALLTYPE EnsureModelDeviceCompatibility(
      winml::LearningModel const& model,
      IModelProto* p_model_proto,
      bool is_float16_supported) override {
    if (is_float16_supported) {
      return S_OK;
    }

    const auto& graph = p_model_proto->get()->graph();

    // The model will not contain fp16 operations if (checked in this order):
    // 1. The model has no fp16 inputs
    // 2. The model does not create any fp16 intermediary tensors via the Cast (to float16) operator
    // 3. The model has no fp16 initializers
    // 4. The model does not have any fp16 outputs
    // Graph elements are visited by const reference: copying them would duplicate
    // every node and every initializer's weight data.

    // 1. Ensure that the model has no fp16 inputs
    for (auto descriptor : model.InputFeatures()) {
      THROW_HR_IF_MSG(
          DXGI_ERROR_UNSUPPORTED,
          IsFeatureDescriptorFp16(descriptor),
          "The model contains a 16-bit input (%ls), but the current device does not support 16-bit float.",
          descriptor.Name().c_str());
    }

    // 2. Ensure that the model does not create any fp16 intermediary
    //    tensors via the Cast (to float16) operator in the default ONNX domain
    for (const auto& node : graph.node()) {
      if (node.op_type() == "Cast" && node.domain().empty()) {
        for (const auto& attribute : node.attribute()) {
          if (attribute.name() == "to") {
            THROW_HR_IF_MSG(
                DXGI_ERROR_UNSUPPORTED,
                attribute.i() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16,
                "The model contains a 16-bit float Cast Op (%s), but the current device does not support 16-bit float.",
                node.name().c_str());
          }
        }
      }
    }

    // 3. Ensure that the model has no fp16 initializers
    for (const auto& initializer : graph.initializer()) {
      THROW_HR_IF_MSG(
          DXGI_ERROR_UNSUPPORTED,
          initializer.data_type() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16,
          "The model contains a 16-bit float initializer (%s), but the current device does not support 16-bit float.",
          initializer.name().c_str());
    }

    // 4. Ensure that the model does not have any fp16 outputs
    for (auto descriptor : model.OutputFeatures()) {
      THROW_HR_IF_MSG(
          DXGI_ERROR_UNSUPPORTED,
          IsFeatureDescriptorFp16(descriptor),
          "The model contains a 16-bit output (%ls), but the current device does not support 16-bit float.",
          descriptor.Name().c_str());
    }

    return S_OK;
  }

  // Maps an allocation made by the provider's default allocator back to its D3D12
  // resource. Always returns nullptr when built without USE_DML.
  ID3D12Resource* STDMETHODCALLTYPE GetD3D12ResourceFromAllocation(onnxruntime::IExecutionProvider* provider, void* allocation) override {
#ifdef USE_DML
    auto d3dResource =
        Dml::GetD3D12ResourceFromAllocation(
            provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault).get(),
            allocation);
    return d3dResource;
#else
    return nullptr;
#endif  // USE_DML
  }

  // Maps Float / Float16 tensor kinds to onnxruntime data types; nullptr for any other kind.
  static onnxruntime::MLDataType GetType(winml::TensorKind kind) {
    switch (kind) {
      case winml::TensorKind::Float:
        return onnxruntime::DataTypeImpl::GetType<float>();
      case winml::TensorKind::Float16:
        return onnxruntime::DataTypeImpl::GetType<onnxruntime::MLFloat16>();
    }
    return nullptr;
  }

  // factory method for creating an ortsessionbuilder from a device.
  // A null device selects the CPU builder; a non-null device selects the DirectML builder,
  // which is only available with USE_DML (otherwise E_NOTIMPL).
  HRESULT STDMETHODCALLTYPE CreateOrtSessionBuilder(
      ID3D12Device* device,
      ID3D12CommandQueue* queue,
      IOrtSessionBuilder** session_builder) override {
    if (device == nullptr) {
      auto builder = wil::MakeOrThrow<CpuOrtSessionBuilder>();
      return builder.CopyTo(__uuidof(IOrtSessionBuilder), reinterpret_cast<void**>(session_builder));
    }
#ifdef USE_DML
    auto builder = wil::MakeOrThrow<DmlOrtSessionBuilder>(device, queue);
    return builder.CopyTo(__uuidof(IOrtSessionBuilder), reinterpret_cast<void**>(session_builder));
#else
    return E_NOTIMPL;
#endif  // USE_DML
  }

  // Reports the key/value element types of a map OrtValue; both stay UNDEFINED
  // (and S_OK is still returned) when the value is not one of the listed map types.
  HRESULT STDMETHODCALLTYPE GetMapType(const OrtValue* ort_value, ONNXTensorElementDataType* key_type, ONNXTensorElementDataType* value_type) override {
    *key_type = *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
    auto type = ort_value->Type();
    if (type == onnxruntime::DataTypeImpl::GetType<onnxruntime::MapStringToString>()) {
      *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
      *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
    } else if (type == onnxruntime::DataTypeImpl::GetType<onnxruntime::MapStringToInt64>()) {
      *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
      *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
    } else if (type == onnxruntime::DataTypeImpl::GetType<onnxruntime::MapStringToFloat>()) {
      *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
      *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
    } else if (type == onnxruntime::DataTypeImpl::GetType<onnxruntime::MapStringToDouble>()) {
      *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
      *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
    } else if (type == onnxruntime::DataTypeImpl::GetType<onnxruntime::MapInt64ToString>()) {
      *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
      *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
    } else if (type == onnxruntime::DataTypeImpl::GetType<onnxruntime::MapInt64ToInt64>()) {
      *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
      *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
    } else if (type == onnxruntime::DataTypeImpl::GetType<onnxruntime::MapInt64ToFloat>()) {
      *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
      *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
    } else if (type == onnxruntime::DataTypeImpl::GetType<onnxruntime::MapInt64ToDouble>()) {
      *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
      *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
    }
    return S_OK;
  }

  // Like GetMapType, for sequence-of-map OrtValues (string->float or int64->float).
  HRESULT STDMETHODCALLTYPE GetVectorMapType(const OrtValue* ort_value, ONNXTensorElementDataType* key_type, ONNXTensorElementDataType* value_type) override {
    *key_type = *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
    auto type = ort_value->Type();
    if (type == onnxruntime::DataTypeImpl::GetType<onnxruntime::VectorMapStringToFloat>()) {
      *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
      *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
    } else if (type == onnxruntime::DataTypeImpl::GetType<onnxruntime::VectorMapInt64ToFloat>()) {
      *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
      *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
    }
    return S_OK;
  }

  // Creates a new AbiCustomRegistryImpl (USE_DML only; otherwise E_NOTIMPL).
  HRESULT STDMETHODCALLTYPE GetCustomRegistry(IMLOperatorRegistry** registry) override {
#ifdef USE_DML
    auto impl = wil::MakeOrThrow<AbiCustomRegistryImpl>();
    *registry = impl.Detach();
    return S_OK;
#else
    return E_NOTIMPL;
#endif  // USE_DML
  }

  // Returns the registry held by the operator provider (USE_DML only; otherwise E_NOTIMPL).
  HRESULT STDMETHODCALLTYPE GetOperatorRegistry(ILearningModelOperatorProviderNative* operator_provider_native, IMLOperatorRegistry** registry) override {
#ifdef USE_DML
    // Retrieve the "operator abi" registry.
    winrt::com_ptr<IMLOperatorRegistry> operator_registry;
    THROW_IF_FAILED(operator_provider_native->GetRegistry(operator_registry.put()));
    *registry = operator_registry.detach();
    return S_OK;
#else
    return E_NOTIMPL;
#endif  // USE_DML
  }

  // Wraps a D3D12 resource as a DML allocation (USE_DML only; otherwise nullptr).
  void* STDMETHODCALLTYPE CreateGPUAllocationFromD3DResource(ID3D12Resource* pResource) override {
#ifdef USE_DML
    return Dml::CreateGPUAllocationFromD3DResource(pResource);
#else
    return nullptr;
#endif  // USE_DML
  }

  // Releases an allocation from CreateGPUAllocationFromD3DResource (no-op without USE_DML).
  void STDMETHODCALLTYPE FreeGPUAllocation(void* ptr) override {
#ifdef USE_DML
    Dml::FreeGPUAllocation(ptr);
#endif  // USE_DML
  }

  // Copies tensor `src` into tensor `dst` through the DML provider (USE_DML only; otherwise E_NOTIMPL).
  HRESULT STDMETHODCALLTYPE CopyTensor(
      onnxruntime::IExecutionProvider* provider,
      OrtValue* src,
      OrtValue* dst) override {
#ifdef USE_DML
    ORT_THROW_IF_ERROR(Dml::CopyTensor(provider, *(src->GetMutable<onnxruntime::Tensor>()), *(dst->GetMutable<onnxruntime::Tensor>())));
    return S_OK;
#else
    return E_NOTIMPL;
#endif  // USE_DML
  }

  // Override select shape inference functions which are incomplete in ONNX with versions that are complete,
  // and are also used in DML kernel registrations. Doing this avoids kernel and shader creation being
  // deferred until first evaluation. It also prevents a situation where inference functions in externally
  // registered schema are reachable only after upstream schema have been revised in a later OS release,
  // which would be a compatibility risk.
  HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() override {
#ifdef USE_DML
    static std::once_flag schema_override_once_flag;
    std::call_once(schema_override_once_flag, []() {
      SchemaInferenceOverrider::OverrideSchemaInferenceFunctions();
    });
    return S_OK;
#else
    return S_OK;  // needs to return S_OK otherwise everything breaks because this gets called from the learningmodel constructor
#endif  // USE_DML
  }

  // Returns a newly allocated copy of the provider's default allocator memory info.
  // Caller owns the result. E_OUTOFMEMORY if the copy cannot be allocated.
  HRESULT STDMETHODCALLTYPE GetProviderMemoryInfo(
      onnxruntime::IExecutionProvider* provider,
      OrtMemoryInfo** memory_info) override {
    auto allocator = provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault);

    const auto& info = allocator->Info();
    // nothrow form: a plain `new` throws instead of returning nullptr, which would
    // make the E_OUTOFMEMORY branch below unreachable.
    *memory_info = new (std::nothrow) OrtMemoryInfo(info.name, info.type, info.device, info.id, info.mem_type);
    if (*memory_info == nullptr) {
      return E_OUTOFMEMORY;
    }
    return S_OK;
  }

  // Returns a newly allocated copy of the memory info of the tensor held by `ort_value`
  // (which is assumed to hold a Tensor). Caller owns the result.
  HRESULT STDMETHODCALLTYPE GetValueMemoryInfo(const OrtValue* ort_value, OrtMemoryInfo** memory_info) override {
    const auto& tensor = ort_value->Get<onnxruntime::Tensor>();
    const auto& info = tensor.Location();
    *memory_info = new (std::nothrow) OrtMemoryInfo(info.name, info.type, info.device, info.id, info.mem_type);
    if (*memory_info == nullptr) {
      return E_OUTOFMEMORY;
    }
    return S_OK;
  }

  // Adapts an onnxruntime::AllocatorPtr to the C OrtAllocator function-pointer table.
  // The wrapper shares ownership of the underlying allocator for its own lifetime.
  struct AllocatorWrapper : public OrtAllocator {
   public:
    explicit AllocatorWrapper(onnxruntime::AllocatorPtr impl) : impl_(std::move(impl)) {
      version = ORT_API_VERSION;
      Alloc = AllocImpl;
      Free = FreeImpl;
      Info = InfoImpl;
    }

    static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) {
      return static_cast<AllocatorWrapper*>(this_)->impl_->Alloc(size);
    }
    static void ORT_API_CALL FreeImpl(struct OrtAllocator* this_, void* p) {
      return static_cast<AllocatorWrapper*>(this_)->impl_->Free(p);
    }
    static const struct OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) {
      return &(static_cast<const AllocatorWrapper*>(this_)->impl_->Info());
    }

   private:
    onnxruntime::AllocatorPtr impl_;
  };

  // Returns a new AllocatorWrapper over the provider's default allocator. Caller owns the
  // result (NOTE(review): it must be deleted as an AllocatorWrapper, since OrtAllocator
  // has no virtual destructor — confirm how callers release it).
  HRESULT STDMETHODCALLTYPE GetProviderAllocator(
      onnxruntime::IExecutionProvider* provider,
      OrtAllocator** allocator) override {
    auto allocator_ptr = provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault);
    *allocator = new (std::nothrow) AllocatorWrapper(std::move(allocator_ptr));
    if (*allocator == nullptr) {
      return E_OUTOFMEMORY;
    }
    return S_OK;
  }
};  // class WinMLAdapter
|
2019-11-15 01:44:07 +00:00
|
|
|
|
2019-11-19 02:31:04 +00:00
|
|
|
// Exported factory: creates a new WinMLAdapter and returns its IWinMLAdapter interface.
// Exceptions must not escape an extern "C" entry point (MSVC /EHsc assumes such
// functions never throw), so any failure, e.g. from MakeOrThrow or the WinMLAdapter
// constructor, is translated into a failure HRESULT.
extern "C" HRESULT STDMETHODCALLTYPE OrtGetWinMLAdapter(IWinMLAdapter** adapter) {
  try {
    // make an adapter instance
    Microsoft::WRL::ComPtr<WinMLAdapter> adapterptr = wil::MakeOrThrow<WinMLAdapter>();
    return adapterptr.CopyTo(__uuidof(IWinMLAdapter), reinterpret_cast<void**>(adapter));
  }
  CATCH_RETURN();
}
|
2019-11-15 01:44:07 +00:00
|
|
|
|
|
|
|
|
// class IOBinding
|
|
|
|
|
// ===============
|
2019-11-19 02:31:04 +00:00
|
|
|
// COM adapter exposing an onnxruntime::IOBinding through the IIOBinding interface.
// Takes sole ownership of the onnxruntime::IOBinding passed to its constructor.
class IOBinding : public Microsoft::WRL::RuntimeClass<
                      Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
                      IIOBinding> {
 private:
  // The wrapped binding. Nothing in this class ever shares it (get() hands out only a raw,
  // non-owning pointer), so unique ownership avoids the needless refcount of a shared_ptr.
  std::unique_ptr<onnxruntime::IOBinding> binding_;
  // Non-owning pointers into binding_'s output values; rebuilt by every GetOutputs() call.
  std::vector<OrtValue*> outputs_weak_;

 public:
  // Takes ownership of `binding`. Explicit so a raw pointer cannot implicitly become an adapter.
  explicit IOBinding(onnxruntime::IOBinding* binding) : binding_(binding) {
  }

  // Returns the wrapped binding without transferring ownership.
  onnxruntime::IOBinding* STDMETHODCALLTYPE get() override {
    return binding_.get();
  }

  // Binds *ort_value as the input called `name`; throws if ORT reports an error.
  // ort_value must be non-null.
  HRESULT STDMETHODCALLTYPE BindInput(const std::string& name, OrtValue* ort_value) override {
    ORT_THROW_IF_ERROR(binding_->BindInput(name, *ort_value));
    return S_OK;
  }

  // Binds the output called `name`. ort_value can be null for unbound outputs; in that case
  // an empty OrtValue is bound instead. Throws if ORT reports an error.
  HRESULT STDMETHODCALLTYPE BindOutput(const std::string& name, OrtValue* ort_value) override {
    if (ort_value == nullptr) {
      OrtValue empty_value = {};
      ORT_THROW_IF_ERROR(binding_->BindOutput(name, empty_value));
    } else {
      ORT_THROW_IF_ERROR(binding_->BindOutput(name, *ort_value));
    }
    return S_OK;
  }

  // Names of the bound outputs, as reported by the wrapped binding.
  const std::vector<std::string>& STDMETHODCALLTYPE GetOutputNames() override {
    return binding_->GetOutputNames();
  }

  // Returns pointers to the wrapped binding's output values. The vector is owned by this
  // object and is rebuilt on each call, so earlier results are invalidated.
  std::vector<OrtValue*>& STDMETHODCALLTYPE GetOutputs() override {
    auto& output_inner = binding_->GetOutputs();
    outputs_weak_.clear();
    outputs_weak_.reserve(output_inner.size());
    for (auto& value : output_inner) {
      outputs_weak_.push_back(&value);
    }
    return outputs_weak_;
  }
};
|
|
|
|
|
|
|
|
|
|
// InferenceSession
|
|
|
|
|
// ================
|
|
|
|
|
|
2019-11-19 02:31:04 +00:00
|
|
|
InferenceSession::InferenceSession(onnxruntime::InferenceSession* session) : session_(session) {
|
2019-11-15 01:44:07 +00:00
|
|
|
}
|
|
|
|
|
|
2019-11-27 02:53:08 +00:00
|
|
|
// Registers the DML graph transformers on the wrapped session.
// Compiles to a no-op when USE_DML is not defined.
void STDMETHODCALLTYPE InferenceSession::RegisterGraphTransformers() {
#ifdef USE_DML
  // Bug 22973884 : Fix issues with BatchNorm + Add and BatchNorm + Mul handling implicit inputs, and move from Winml to ORT
  GraphTransformerHelpers::RegisterGraphTransformers(session_.get());
#endif  // USE_DML (text after #endif must be a comment; bare tokens are ill-formed)
}
|
|
|
|
|
|
|
|
|
|
// Creates a new ORT IOBinding for this session and returns it wrapped in an IIOBinding.
// Throws if ORT fails to create the binding or the wrapper cannot be allocated.
// NOTE(review): release() runs before MakeOrThrow allocates the wrapper, so if that
// allocation throws, the released binding has no owner.
HRESULT STDMETHODCALLTYPE InferenceSession::NewIOBinding(IIOBinding** io_binding) {
  std::unique_ptr<onnxruntime::IOBinding> inner_binding;
  ORT_THROW_IF_ERROR(session_->NewIOBinding(&inner_binding));

  auto wrapper = wil::MakeOrThrow<IOBinding>(inner_binding.release());
  return wrapper.CopyTo(__uuidof(IIOBinding), reinterpret_cast<void**>(io_binding));
}
|
|
|
|
|
|
|
|
|
|
// Runs the session using the given options and the ORT binding wrapped by io_binding.
// Both pointers are dereferenced unconditionally; throws if ORT reports an error.
HRESULT STDMETHODCALLTYPE InferenceSession::Run(const onnxruntime::RunOptions* run_options, IIOBinding* io_binding) {
  onnxruntime::IOBinding& inner_binding = *io_binding->get();
  const onnxruntime::RunOptions& options = *run_options;
  ORT_THROW_IF_ERROR(session_->Run(options, inner_binding));
  return S_OK;
}
|
2019-11-15 01:44:07 +00:00
|
|
|
// Starts ORT profiling on the wrapped session, logging through the default logger of the
// LotusEnvironment singleton. Always reports S_OK.
HRESULT STDMETHODCALLTYPE InferenceSession::StartProfiling() {
  auto* session = session_.get();
  // Kept as a single expression so the singleton handle stays alive for the whole call.
  session->StartProfiling(PheonixSingleton<WinML::LotusEnvironment>()->GetDefaultLogger());
  return S_OK;
}
|
2019-11-15 01:44:07 +00:00
|
|
|
// Stops ORT profiling on the wrapped session; any value EndProfiling returns is discarded.
HRESULT STDMETHODCALLTYPE InferenceSession::EndProfiling() {
  auto* session = session_.get();
  session->EndProfiling();
  return S_OK;
}
|
|
|
|
|
|
2019-11-15 01:44:07 +00:00
|
|
|
// Loads the model held by model_proto into the wrapped session.
// The proto is detach()ed, so the session receives its own copy, owned via unique_ptr.
// Throws if ORT fails to load the model.
HRESULT STDMETHODCALLTYPE
InferenceSession::LoadModel(
    IModelProto* model_proto) {
  // Sessions like to have their very own copy of the model proto, hence detach().
  std::unique_ptr<ONNX_NAMESPACE::ModelProto> owned_proto{model_proto->detach()};

  // Load(unique_ptr<ModelProto>) is reached through InferenceSessionProtectedLoadAccessor.
  // NOTE(review): this static_cast assumes that accessor type only re-exposes the protected
  // Load overload without adding state — confirm against its declaration.
  auto* loader = static_cast<InferenceSessionProtectedLoadAccessor*>(session_.get());
  ORT_THROW_IF_ERROR(loader->Load(std::move(owned_proto)));
  return S_OK;
}
|
|
|
|
|
|
2019-11-15 01:44:07 +00:00
|
|
|
// Registers the ORT custom registries derived from an IMLOperatorRegistry with the session.
// A null registry is treated as "nothing to register" and succeeds. In builds without
// USE_DML the registry is ignored and S_OK is returned. Throws if ORT rejects a registry.
HRESULT STDMETHODCALLTYPE
InferenceSession::RegisterCustomRegistry(
    IMLOperatorRegistry* registry) {
  RETURN_HR_IF(S_OK, registry == nullptr);

#ifdef USE_DML
  auto custom_registries = GetLotusCustomRegistries(registry);

  // Register
  for (auto& custom_registry : custom_registries) {
    ORT_THROW_IF_ERROR(session_->RegisterCustomRegistry(custom_registry));
  }
#endif  // USE_DML (text after #endif must be a comment; bare tokens are ill-formed)

  return S_OK;
}
|
|
|
|
|
|
|
|
|
|
// Forwards to Dml::FlushContext for the given DML execution provider.
// No-op in builds without USE_DML.
void STDMETHODCALLTYPE InferenceSession::FlushContext(onnxruntime::IExecutionProvider* dml_provider) {
#ifdef USE_DML
  Dml::FlushContext(dml_provider);
#else
  // Nothing to do without DML; reference the parameter to silence unused-parameter warnings.
  (void)dml_provider;
#endif  // USE_DML
}
|
|
|
|
|
|
|
|
|
|
// Forwards to Dml::TrimUploadHeap for the given DML execution provider.
// No-op in builds without USE_DML.
void STDMETHODCALLTYPE InferenceSession::TrimUploadHeap(onnxruntime::IExecutionProvider* dml_provider) {
#ifdef USE_DML
  Dml::TrimUploadHeap(dml_provider);
#else
  // Nothing to do without DML; reference the parameter to silence unused-parameter warnings.
  (void)dml_provider;
#endif  // USE_DML
}
|
|
|
|
|
|
2019-11-15 01:44:07 +00:00
|
|
|
// Forwards to Dml::ReleaseCompletedReferences for the given DML execution provider.
// No-op in builds without USE_DML.
void STDMETHODCALLTYPE InferenceSession::ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) {
#ifdef USE_DML
  Dml::ReleaseCompletedReferences(dml_provider);
#else
  // Nothing to do without DML; reference the parameter to silence unused-parameter warnings.
  (void)dml_provider;
#endif  // USE_DML
}
|
|
|
|
|
|
2019-11-25 20:50:04 +00:00
|
|
|
// Cross-device copying of a single input is not implemented by this adapter:
// the arguments are ignored and E_NOTIMPL is always returned.
HRESULT STDMETHODCALLTYPE InferenceSession::CopyOneInputAcrossDevices(
    const char* /*input_name*/,
    const OrtValue* /*orig_mlvalue*/,
    OrtValue** /*new_mlvalue*/) {
  return E_NOTIMPL;
}
|
|
|
|
|
|
2019-11-08 00:50:24 +00:00
|
|
|
} // namespace Windows::AI::MachineLearning::Adapter
|