onnxruntime/winml/dll/module.cpp
Paul McDaniel 5350abe19d
LearningModelSession is cleaned up to use the adapter, and parts of b… (#2382)
This is a big PR. We are going to move it up to layer_dev, which is still an L3, so it is still safe to work there in an agile way.

We are going to move this into the L3 so that Ryan can start doing integration testing.

We will pause for a full code review and integration test results before going into the L2.

>>>> raw comments from previous commits >>> 

* LearningModelSession is cleaned up to use the adapter, and parts of binding are.
* moved everything into the winmladapter
made it all nano-COM, using WRL to construct objects on the ORT side.
base interfaces for everything for winml to call
cleaned up a bunch of winml to use the base interfaces.
* more pieces
* GetData across the abi.
* renamed some namespaces
cleaned up OrtValue
cleaned up Tensor
cleaned up custom ops.
everything *but* LearningModel should be clean
* make sure it's building.   winml.dll is still a monolith.
2019-11-14 17:44:07 -08:00

100 lines
No EOL
3.3 KiB
C++

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include "pch.h"
#include <windows.h>
#include <Hstring.h>
#include "WinMLProfiler.h"
#include "LearningModelDevice.h"
using namespace winrt::Windows::AI::MachineLearning::implementation;
// WIL result-telemetry fallback (installed from DllMain via
// wil::SetResultTelemetryFallback). Forwards a failure that WIL has not yet
// reported to telemetry_helper.LogRuntimeError.
//
// alreadyReported: true when WIL has already reported this failure; in that
//                  case nothing is logged.
// failure:         WIL failure record. A null pszMessage is logged as an empty
//                  string; the message is converted with winrt::to_string
//                  before logging.
//
// NOTE(review): hstring construction and winrt::to_string allocate and can
// throw. Because this function is WI_NOEXCEPT, such a throw would terminate
// the process. Confirm that is acceptable for a telemetry path.
void __stdcall OnErrorReported(bool alreadyReported, wil::FailureInfo const &failure) WI_NOEXCEPT {
  if (alreadyReported) {
    return;  // already surfaced once; avoid double-logging
  }

  const wchar_t* const rawMessage = failure.pszMessage ? failure.pszMessage : L"";
  const winrt::hstring message(rawMessage);

  telemetry_helper.LogRuntimeError(
      failure.hr,
      winrt::to_string(message),
      failure.pszFile,
      failure.pszFunction,
      failure.uLineNumber);
}
// DLL entry point.
//
// DLL_PROCESS_ATTACH:
//   - turns off per-thread attach/detach notifications;
//   - registers the telemetry provider and logs the DLL-load event;
//   - enables and resets CPU profiling when the device is sampled at measure
//     level;
//   - installs OnErrorReported as the WIL result-telemetry fallback.
//
// DLL_PROCESS_DETACH:
//   - logs runtime perf data, then unregisters the telemetry provider;
//   - releases LearningModelDevice state via LearningModelDevice::DllUnload().
//     Release builds skip this during process termination; see below.
//
// Other notifications are ignored. The function always reports success.
//
// NOTE(review): DllMain runs under the loader lock. Confirm that
// telemetry_helper, profiler and LearningModelDevice::DllUnload do not load
// other DLLs or wait on other threads.
extern "C" BOOL WINAPI DllMain(_In_ HINSTANCE hInstance, DWORD dwReason, _In_ void* lpvReserved) {
switch (dwReason) {
case DLL_PROCESS_ATTACH:
// Only process attach/detach are handled below, so thread notifications are disabled.
DisableThreadLibraryCalls(hInstance);
// Register the TraceLogging provider feeding telemetry. It's OK if this fails;
// trace logging calls just become no-ops.
telemetry_helper.Register();
// Log Dll load
telemetry_helper.LogDllAttachEvent();
// Enable Profiling if the device is sampled at measure level
if (telemetry_helper.IsMeasureSampled()) {
profiler.Enable(ProfilerType::CPU);
profiler.Reset(ProfilerType::CPU);
}
// Send failures reported through WIL to OnErrorReported, which logs them to telemetry.
wil::SetResultTelemetryFallback(&OnErrorReported);
break;
case DLL_PROCESS_DETACH:
// Log perf data before the telemetry provider is unregistered below.
// NOTE(review): the meaning of the `true` argument is not visible here; confirm.
telemetry_helper.LogRuntimePerf(profiler, true);
// Unregister Trace Logging Provider feeding telemetry
telemetry_helper.UnRegister();
#ifdef NDEBUG
// lpvReserved is nullptr for a dynamic unload (e.g. FreeLibrary) and non-null
// when the process is terminating.
bool dynamicUnload = (lpvReserved == nullptr);
//
// Release builds skip explicit teardown during process termination, because the
// OS reclaims memory more quickly and correctly during process shutdown.
// Debug builds (NDEBUG undefined) always tear down so leak-detection tracing
// stays accurate; lpvReserved is unused in those builds.
//
if (dynamicUnload)
#endif
{
LearningModelDevice::DllUnload();
}
break;
}
// Always report success to the loader.
return true;
}
// Exported factory for the custom operator registry owned by the WinML adapter.
//
// registry: out-parameter that receives the registry returned by the adapter's
//           GetCustomRegistry. It is cleared to nullptr before any work, so a
//           failure never leaves a stale pointer behind.
//
// Returns:
//   - E_POINTER when registry itself is null;
//   - the HRESULT from GetCustomRegistry;
//   - otherwise the failure translated by CATCH_RETURN, e.g. an
//     OrtGetWinMLAdapter failure thrown by WINML_THROW_IF_FAILED.
extern "C" HRESULT WINAPI MLCreateOperatorRegistry(_COM_Outptr_ IMLOperatorRegistry** registry) try {
  // Reject a null out-parameter instead of dereferencing it. An access
  // violation is not a C++ exception (under the default /EHsc model), so
  // CATCH_RETURN could not turn it into an HRESULT.
  if (registry == nullptr) {
    return E_POINTER;
  }
  *registry = nullptr;

  winrt::com_ptr<_winmla::IWinMLAdapter> adapter;
  WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put()));
  return adapter->GetCustomRegistry(registry);
}
CATCH_RETURN();
// Always reports that the DLL is still in use (S_FALSE), so
// CoFreeUnusedLibraries never frees windows.ai.machinelearning.dll.
//
// Why: COM objects handed across the DLL boundary are not reference counted on
// this path, so the DLL cannot tell whether callers still hold them. Examples
// are AbiCustomRegistry, IMLOperatorKernelContext and IMLOperatorTensor, many
// of which are exposed through MLCreateOperatorRegistry.
//
// A real implementation would need one of two things:
//   - reference counting for every non-WinRT COM object shared across the
//     boundary; or
//   - boundary APIs hardened so that no additional outstanding references are
//     cached by callers.
// Identifying and curating the complete list of IUnknown-based objects involved
// was considered too complex for RS5, so this remains a temporary workaround.
//
// There are no known code paths that rely on opportunistic DLL unload.
STDAPI DllCanUnloadNow() {
  constexpr HRESULT kStillInUse = S_FALSE;
  return kStillInUse;
}
// Exported WinRT activation entry point. Hands the class-id lookup to
// WINRT_GetActivationFactory (presumably the C++/WinRT module entry point
// generated into module.g.cpp) and returns its result unchanged.
//
// classId: runtime class name being activated.
// factory: receives the activation factory produced by
//          WINRT_GetActivationFactory.
STDAPI DllGetActivationFactory(HSTRING classId, void** factory) {
  const HRESULT result = WINRT_GetActivationFactory(classId, factory);
  return result;
}