2019-08-15 22:27:05 +00:00
|
|
|
// Copyright (c) Microsoft Corporation.
|
|
|
|
|
// Licensed under the MIT License.
|
|
|
|
|
|
|
|
|
|
#include "pch.h"
|
|
|
|
|
|
2019-12-03 23:31:22 +00:00
|
|
|
#ifdef USE_DML
|
|
|
|
|
|
2019-08-15 22:27:05 +00:00
|
|
|
// Needed to work around the fact that OnnxRuntime defines ERROR
|
|
|
|
|
#ifdef ERROR
|
|
|
|
|
#undef ERROR
|
|
|
|
|
#endif
|
|
|
|
|
#include "core/session/inference_session.h"
|
|
|
|
|
// Restore ERROR define
|
|
|
|
|
#define ERROR 0
|
|
|
|
|
|
|
|
|
|
#include "DmlOrtSessionBuilder.h"
|
|
|
|
|
|
|
|
|
|
// winml includes
|
|
|
|
|
#include "core/providers/dml/GraphTransformers/GraphTransformerHelpers.h"
|
2019-12-03 23:31:22 +00:00
|
|
|
#include "CustomRegistryHelper.h"
|
2019-08-15 22:27:05 +00:00
|
|
|
#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h"
|
|
|
|
|
#include "LearningModelDevice.h"
|
|
|
|
|
#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h"
|
|
|
|
|
|
|
|
|
|
// ort includes
|
|
|
|
|
#include "core/framework/op_kernel.h"
|
|
|
|
|
#include "core/framework/op_node_proto_helper.h"
|
|
|
|
|
#include "core/framework/customRegistry.h"
|
|
|
|
|
#include "core/framework/data_transfer.h"
|
2019-11-20 02:15:47 +00:00
|
|
|
#include "core/session/abi_session_options_impl.h"
|
2019-08-15 22:27:05 +00:00
|
|
|
|
|
|
|
|
using namespace Windows::AI::MachineLearning;
|
|
|
|
|
|
2019-11-15 01:44:07 +00:00
|
|
|
namespace Windows::AI::MachineLearning::Adapter {
|
|
|
|
|
|
2019-08-15 22:27:05 +00:00
|
|
|
// Constructs a session builder bound to a D3D12 device and command queue.
//
// device: D3D12 device later used by CreateSession to create the DML device.
// queue:  D3D12 command queue later handed to the DML execution provider.
//
// Both pointers are stored in the device_/queue_ members through copy_from.
// NOTE(review): copy_from presumably takes an additional COM reference (as
// winrt::com_ptr does) - confirm against the member declarations.
DmlOrtSessionBuilder::DmlOrtSessionBuilder(
    ID3D12Device* device,
    ID3D12CommandQueue* queue) {
  device_.copy_from(device);
  queue_.copy_from(queue);
}
|
2019-08-15 22:27:05 +00:00
|
|
|
|
|
|
|
|
HRESULT
|
|
|
|
|
DmlOrtSessionBuilder::CreateSessionOptions(
|
2019-12-04 22:08:16 +00:00
|
|
|
OrtSessionOptions** options) try {
|
2019-11-20 02:15:47 +00:00
|
|
|
RETURN_HR_IF_NULL(E_POINTER, options);
|
2019-08-15 22:27:05 +00:00
|
|
|
|
2019-11-20 02:15:47 +00:00
|
|
|
Ort::ThrowOnError(Ort::GetApi().CreateSessionOptions(options));
|
|
|
|
|
std::unique_ptr<Ort::SessionOptions> session_options = std::make_unique<Ort::SessionOptions>(*options);
|
2019-11-18 17:51:39 +00:00
|
|
|
|
2019-11-20 02:15:47 +00:00
|
|
|
// set the graph optimization level to all (used to be called level 3)
|
|
|
|
|
session_options->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
|
2019-08-15 22:27:05 +00:00
|
|
|
|
|
|
|
|
// Disable the mem pattern session option for DML. It will cause problems with how memory is allocated.
|
2019-11-20 02:15:47 +00:00
|
|
|
session_options->DisableMemPattern();
|
2019-08-15 22:27:05 +00:00
|
|
|
|
2019-11-20 02:15:47 +00:00
|
|
|
// all done with the smart ptr
|
|
|
|
|
session_options.release();
|
2019-08-15 22:27:05 +00:00
|
|
|
return S_OK;
|
|
|
|
|
}
|
2019-12-04 22:08:16 +00:00
|
|
|
WINML_CATCH_ALL_COM
|
2019-08-15 22:27:05 +00:00
|
|
|
|
|
|
|
|
// Registers the custom registries obtained from `registry` with `p_session`.
//
// p_session: session that receives the registries; only required (and only checked)
//            when `registry` is non-null.
// registry:  operator registry to translate via GetLotusCustomRegistries; may be null,
//            in which case there is nothing to do.
//
// Returns S_OK when `registry` is null or all registrations succeed, and E_POINTER when
// `registry` is non-null but `p_session` is null. A failed registration is reported
// through ORT_THROW_IF_ERROR rather than through the returned HRESULT.
static HRESULT
RegisterCustomRegistry(
    onnxruntime::InferenceSession* p_session,
    IMLOperatorRegistry* registry) {
  // No operator registry supplied: nothing to register.
  if (registry == nullptr) {
    return S_OK;
  }

  RETURN_HR_IF_NULL(E_POINTER, p_session);

  auto lotus_registries = GetLotusCustomRegistries(registry);

  // Register each translated registry with the session.
  for (auto& lotus_registry : lotus_registries) {
    ORT_THROW_IF_ERROR(p_session->RegisterCustomRegistry(lotus_registry));
  }

  return S_OK;
}
|
|
|
|
|
|
|
|
|
|
// Creates a DirectML device on top of the given D3D12 device.
//
// d3d12Device: D3D12 device passed to DMLCreateDevice1 (and, in _DEBUG builds, queried
//              for ID3D12DebugDevice). It is dereferenced without a null check.
//
// Returns the created IDMLDevice. Failure to load DirectML.dll, to resolve
// DMLCreateDevice1, or to create the device is reported through the WIL THROW_* macros.
Microsoft::WRL::ComPtr<IDMLDevice> CreateDmlDevice(ID3D12Device* d3d12Device) {
  using DmlCreateDevice1Fn = decltype(&DMLCreateDevice1);

  // Dynamically load DML to avoid WinML taking a static dependency on DirectML.dll
  wil::unique_hmodule directml_module(LoadLibraryW(L"DirectML.dll"));
  THROW_LAST_ERROR_IF(!directml_module);

  auto create_device_fn = reinterpret_cast<DmlCreateDevice1Fn>(
      GetProcAddress(directml_module.get(), "DMLCreateDevice1"));
  THROW_LAST_ERROR_IF(!create_device_fn);

  DML_CREATE_DEVICE_FLAGS create_flags = DML_CREATE_DEVICE_FLAG_NONE;

  // Enable the DML debug layer in DEBUG builds, if the D3D12 debug layer is also enabled
#if _DEBUG
  {
    // The debug interface is only probed for; the scope drops the reference again
    // before the DML device is created.
    Microsoft::WRL::ComPtr<ID3D12DebugDevice> debug_device;
    if (SUCCEEDED(d3d12Device->QueryInterface(IID_PPV_ARGS(&debug_device)))) {
      create_flags |= DML_CREATE_DEVICE_FLAG_DEBUG;
    }
  }
#endif

  Microsoft::WRL::ComPtr<IDMLDevice> dml_device;
  THROW_IF_FAILED(create_device_fn(d3d12Device, create_flags, DML_FEATURE_LEVEL_2_0, IID_PPV_ARGS(&dml_device)));

  // Keep DirectML.dll loaded by leaking the handle. This is equivalent behavior to if we delay-loaded the DLL.
  directml_module.release();

  return dml_device;
}
|
|
|
|
|
|
|
|
|
|
HRESULT DmlOrtSessionBuilder::CreateSession(
|
2019-11-20 02:15:47 +00:00
|
|
|
OrtSessionOptions* options,
|
2019-11-27 23:50:49 +00:00
|
|
|
winmla::IInferenceSession** p_session,
|
2019-12-04 22:08:16 +00:00
|
|
|
onnxruntime::IExecutionProvider** pp_provider) try {
|
2019-08-15 22:27:05 +00:00
|
|
|
RETURN_HR_IF_NULL(E_POINTER, p_session);
|
|
|
|
|
RETURN_HR_IF_NULL(E_POINTER, pp_provider);
|
|
|
|
|
RETURN_HR_IF(E_POINTER, *pp_provider != nullptr);
|
|
|
|
|
|
2019-11-08 21:23:44 +00:00
|
|
|
auto p_d3d_device = device_.get();
|
|
|
|
|
auto p_queue = queue_.get();
|
2019-08-15 22:27:05 +00:00
|
|
|
|
|
|
|
|
Microsoft::WRL::ComPtr<IDMLDevice> dmlDevice = CreateDmlDevice(p_d3d_device);
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<onnxruntime::IExecutionProvider> gpu_provider = Dml::CreateExecutionProvider(dmlDevice.Get(), p_queue);
|
2019-11-20 02:15:47 +00:00
|
|
|
auto session = std::make_unique<onnxruntime::InferenceSession>(options->value);
|
2019-08-15 22:27:05 +00:00
|
|
|
|
|
|
|
|
// Cache the provider's raw pointer
|
|
|
|
|
*pp_provider = gpu_provider.get();
|
|
|
|
|
|
2019-11-11 22:34:19 +00:00
|
|
|
ORT_THROW_IF_ERROR(session->RegisterExecutionProvider(std::move(gpu_provider)));
|
2019-08-15 22:27:05 +00:00
|
|
|
|
2019-11-15 01:44:07 +00:00
|
|
|
// assign the session to the out parameter
|
2019-11-27 23:50:49 +00:00
|
|
|
auto sessionptr = wil::MakeOrThrow<winmla::InferenceSession>(session.release());
|
|
|
|
|
RETURN_IF_FAILED(sessionptr.CopyTo(_uuidof(winmla::IInferenceSession), (void**)p_session));
|
2019-08-15 22:27:05 +00:00
|
|
|
|
|
|
|
|
return S_OK;
|
|
|
|
|
}
|
2019-12-04 22:08:16 +00:00
|
|
|
WINML_CATCH_ALL_COM
|
2019-08-15 22:27:05 +00:00
|
|
|
|
|
|
|
|
HRESULT DmlOrtSessionBuilder::Initialize(
|
2019-11-27 23:50:49 +00:00
|
|
|
winmla::IInferenceSession* p_session,
|
2019-12-04 22:08:16 +00:00
|
|
|
onnxruntime::IExecutionProvider* p_provider) try {
|
2019-08-15 22:27:05 +00:00
|
|
|
RETURN_HR_IF_NULL(E_INVALIDARG, p_session);
|
|
|
|
|
RETURN_HR_IF_NULL(E_INVALIDARG, p_provider);
|
|
|
|
|
|
|
|
|
|
// OnnxRuntime uses the default rounding mode when calling the session's allocator.
|
|
|
|
|
// During initialization, OnnxRuntime allocates weights, which are permanent across session
|
|
|
|
|
// lifetime and can be large, so shouldn't be rounded.
|
|
|
|
|
Dml::SetDefaultRoundingMode(p_provider, AllocatorRoundingMode::Disabled);
|
|
|
|
|
|
2019-11-15 01:44:07 +00:00
|
|
|
ORT_THROW_IF_ERROR(p_session->get()->Initialize());
|
2019-08-15 22:27:05 +00:00
|
|
|
|
|
|
|
|
Dml::SetDefaultRoundingMode(p_provider, AllocatorRoundingMode::Enabled);
|
|
|
|
|
|
|
|
|
|
// Flush the D3D12 work from the DML execution provider
|
|
|
|
|
Dml::FlushContext(p_provider);
|
|
|
|
|
|
|
|
|
|
return S_OK;
|
2019-11-15 01:44:07 +00:00
|
|
|
}
|
2019-12-04 22:08:16 +00:00
|
|
|
WINML_CATCH_ALL_COM
|
2019-11-15 01:44:07 +00:00
|
|
|
|
2019-12-03 23:31:22 +00:00
|
|
|
} // Windows::AI::MachineLearning::Adapter
|
|
|
|
|
|
|
|
|
|
#endif USE_DML
|