From 7f9a7f5abef28851a5f1dfb8eab920e6f4eb8a91 Mon Sep 17 00:00:00 2001 From: Paul McDaniel Date: Fri, 15 Nov 2019 16:47:33 -0800 Subject: [PATCH] added a wrapper for RoGetActivationFactory to hook back into winml for creating winml objects. fixes model load. --- .../lib/Api.Core/FeatureDescriptorFactory.cpp | 92 +++++++++++++++++++ winml/lib/Api.Core/WinMLAdapter.cpp | 2 +- 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/winml/lib/Api.Core/FeatureDescriptorFactory.cpp b/winml/lib/Api.Core/FeatureDescriptorFactory.cpp index 62ea8aba96..cdc5a1b5bc 100644 --- a/winml/lib/Api.Core/FeatureDescriptorFactory.cpp +++ b/winml/lib/Api.Core/FeatureDescriptorFactory.cpp @@ -41,6 +41,98 @@ static const char* c_supported_nominal_ranges[] = "NominalRange_0_255"}; namespace Windows::AI::MachineLearning { + + +// since this code is now running inside ONNXRUNTIME we need to shortcut +// this a bit when creating winrt objects. This will help. + +/* extern "C" +HRESULT __stdcall OS_RoGetActivationFactory(HSTRING classId, GUID const& iid, void** factory) noexcept; + +#ifdef _M_IX86 +#pragma comment(linker, "/alternatename:_OS_RoGetActivationFactory@12=_RoGetActivationFactory@12") +#else +#pragma comment(linker, "/alternatename:OS_RoGetActivationFactory=RoGetActivationFactory") +#endif +*/ + +bool starts_with(std::wstring_view value, std::wstring_view match) noexcept +{ + return 0 == value.compare(0, match.size(), match); +} + +EXTERN_C IMAGE_DOS_HEADER __ImageBase; + +std::wstring GetModulePath() +{ + std::wstring val; + wchar_t modulePath[MAX_PATH] = { 0 }; + GetModuleFileNameW((HINSTANCE)&__ImageBase, modulePath, _countof(modulePath)); + wchar_t drive[_MAX_DRIVE]; + wchar_t dir[_MAX_DIR]; + wchar_t filename[_MAX_FNAME]; + wchar_t ext[_MAX_EXT]; + _wsplitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT); + + val = drive; + val += dir; + + return val; +} + +extern "C" +int32_t WINRT_CALL WINRT_RoGetActivationFactory(void* classId, winrt::guid const& iid, void** factory) noexcept +{ + *factory = nullptr; + HSTRING classId_hstring = (HSTRING)classId; + std::wstring_view name{ WindowsGetStringRawBuffer(classId_hstring, nullptr), WindowsGetStringLen(classId_hstring) }; + HMODULE library{ nullptr }; + + std::wstring winmlDllPath = GetModulePath() + L"Windows.AI.MachineLearning.dll"; + + if (starts_with(name, L"Windows.AI.MachineLearning.")) + { + const wchar_t* libPath = winmlDllPath.c_str(); + library = LoadLibraryW(libPath); + } + else + { + return RoGetActivationFactory(classId_hstring, iid, factory); + } + + if (!library) + { + return HRESULT_FROM_WIN32(GetLastError()); + } + + using DllGetActivationFactory = HRESULT __stdcall(HSTRING classId, void** factory); + auto call = reinterpret_cast(GetProcAddress(library, "DllGetActivationFactory")); + + if (!call) + { + HRESULT const hr = HRESULT_FROM_WIN32(GetLastError()); + WINRT_VERIFY(FreeLibrary(library)); + return hr; + } + + winrt::com_ptr activation_factory; + HRESULT const hr = call(classId_hstring, activation_factory.put_void()); + + if (FAILED(hr)) + { + WINRT_VERIFY(FreeLibrary(library)); + return hr; + } + + if (winrt::guid(iid) != winrt::guid_of()) + { + return activation_factory->QueryInterface(iid, factory); + } + + *factory = activation_factory.detach(); + return S_OK; +} + // Forward declare CreateFeatureDescriptor static winml::ILearningModelFeatureDescriptor CreateFeatureDescriptor( diff --git a/winml/lib/Api.Core/WinMLAdapter.cpp b/winml/lib/Api.Core/WinMLAdapter.cpp index b40318f381..a60118d5e4 100644 --- a/winml/lib/Api.Core/WinMLAdapter.cpp +++ b/winml/lib/Api.Core/WinMLAdapter.cpp @@ -327,7 +327,7 @@ public: auto model_proto_inner = new onnx::ModelProto(); THROW_HR_IF_MSG( E_INVALIDARG, - !model_proto_inner->ParseFromZeroCopyStream(&stream) == false, + model_proto_inner->ParseFromZeroCopyStream(&stream) == false, "The stream failed to parse."); auto model_proto_outer = wil::MakeOrThrow(model_proto_inner);