onnxruntime/winml/test/common/dllload.cpp
Tiago Koji Castro Shibata 0f5f17c175
WinML CI (#2412)
* Pass flags to build/test WinML in CI

* Add initial CMake config for unit tests in WinML

* Set winml_unittests standard to C++17

* Add WinML API tests and port them to googletest

* Install WinML test collateral

* Add LearningModelSessionAPITests ported to googletest

* Fix WinML test files encoding

* Add GPU tests

* Add parameterized test, skip GPU tests

* Enable precompiled header

* Remove unused code and collateral

* Remove brand images

* Add dllload.cpp

* Remove images not used in API tests

* Add LICENSE.md to image collaterals

* Add models with licenses

* Remove FNS Candy tests

* Add API test models

* Add ModelInSubdirectory

* Install collaterals post-build with copy_if_different, split common lib

* fix warnings

* Link to gtest_main
2019-11-21 16:55:32 -08:00

75 lines
2.3 KiB
C++

#include "Std.h"
#include "fileHelpers.h"
#include <winstring.h>
// Declares OS_RoGetActivationFactory with C linkage and the same parameter list
// as RoGetActivationFactory (class-name HSTRING, requested IID, out pointer).
// No definition is provided here; the linker pragmas below alias it.
extern "C"
{
HRESULT __stdcall OS_RoGetActivationFactory(HSTRING classId, GUID const& iid, void** factory) noexcept;
}
// /alternatename:A=B tells the MSVC linker to use symbol B when symbol A has no
// definition. Calls to OS_RoGetActivationFactory therefore reach the system
// RoGetActivationFactory. WINRT_RoGetActivationFactory uses it to forward every
// class that is not in the Windows.AI.MachineLearning namespace to the OS.
// 32-bit x86 needs the decorated __stdcall names: a leading underscore, plus
// "@12" for the 12 bytes of arguments (three 4-byte parameters). Other targets
// use the undecorated names.
#ifdef _M_IX86
#pragma comment(linker, "/alternatename:_OS_RoGetActivationFactory@12=_RoGetActivationFactory@12")
#else
#pragma comment(linker, "/alternatename:OS_RoGetActivationFactory=RoGetActivationFactory")
#endif
// Reports whether `value` begins with every character of `match`.
// An empty `match` is a prefix of any string; a `match` longer than `value`
// can never be a prefix.
bool starts_with(std::wstring_view value, std::wstring_view match) noexcept
{
    if (match.size() > value.size())
    {
        return false;
    }
    // substr cannot throw here: position 0 is always within range.
    return value.substr(0, match.size()) == match;
}
// Activation-factory hook.
// - Classes whose name starts with "Windows.AI.MachineLearning." are activated
//   from the Windows.AI.MachineLearning.dll found in FileHelpers::GetWinMLPath().
//   The DLL's exported DllGetActivationFactory is called directly.
// - Every other class is forwarded to OS_RoGetActivationFactory, the system
//   RoGetActivationFactory.
//
// classId_hstring: runtime class name to activate.
// iid:             interface requested on the activation factory.
// factory:         receives the factory on success; reset to nullptr on entry.
//
// Returns S_OK on success. Otherwise returns the failing HRESULT from
// LoadLibraryW/GetProcAddress (via GetLastError), DllGetActivationFactory, or
// QueryInterface, or whatever the OS implementation returns.
//
// On success the module is deliberately left loaded, because the returned
// factory's code lives in it. Every failure path after a successful
// LoadLibraryW balances it with FreeLibrary.
HRESULT __stdcall WINRT_RoGetActivationFactory(HSTRING classId_hstring, GUID const& iid, void** factory) noexcept
{
    *factory = nullptr;
    std::wstring_view const name{ WindowsGetStringRawBuffer(classId_hstring, nullptr), WindowsGetStringLen(classId_hstring) };

    // Forward non-WinML classes before doing any other work. This way the DLL
    // path (a FileHelpers call plus a heap allocation inside a noexcept
    // function) is only built when it is actually needed.
    if (!starts_with(name, L"Windows.AI.MachineLearning."))
    {
        return OS_RoGetActivationFactory(classId_hstring, iid, factory);
    }

    std::wstring const winmlDllPath = FileHelpers::GetWinMLPath() + L"Windows.AI.MachineLearning.dll";
    HMODULE const library = LoadLibraryW(winmlDllPath.c_str());
    if (!library)
    {
        return HRESULT_FROM_WIN32(GetLastError());
    }

    using DllGetActivationFactory = HRESULT __stdcall(HSTRING classId, void** factory);
    auto const call = reinterpret_cast<DllGetActivationFactory*>(GetProcAddress(library, "DllGetActivationFactory"));
    if (!call)
    {
        HRESULT const hr = HRESULT_FROM_WIN32(GetLastError());
        WINRT_VERIFY(FreeLibrary(library));
        return hr;
    }

    winrt::com_ptr<winrt::Windows::Foundation::IActivationFactory> activation_factory;
    HRESULT const hr = call(classId_hstring, activation_factory.put_void());
    if (FAILED(hr))
    {
        WINRT_VERIFY(FreeLibrary(library));
        return hr;
    }

    if (winrt::guid(iid) != winrt::guid_of<winrt::Windows::Foundation::IActivationFactory>())
    {
        HRESULT const qi_hr = activation_factory->QueryInterface(iid, factory);
        if (FAILED(qi_hr))
        {
            // Balance the LoadLibraryW as the other failure paths do. Drop our
            // factory reference first: its Release runs code inside the module.
            activation_factory = nullptr;
            WINRT_VERIFY(FreeLibrary(library));
        }
        return qi_hr;
    }

    *factory = activation_factory.detach();
    return S_OK;
}
// Overload taking an untyped class-id pointer and a winrt::guid. It converts
// the arguments to HSTRING/GUID and forwards to the HSTRING/GUID overload of
// WINRT_RoGetActivationFactory, returning its HRESULT as int32_t.
// NOTE(review): presumably this is the signature C++/WinRT's activation path
// links against; confirm against the C++/WinRT version in use.
int32_t __stdcall WINRT_RoGetActivationFactory(void* classId, winrt::guid const& iid, void** factory) noexcept
{
    // Both explicit conversions are required. Without the HSTRING cast the call
    // resolves back to this overload (infinite recursion). Without the GUID cast
    // the call is ambiguous between the two overloads.
    return WINRT_RoGetActivationFactory(static_cast<HSTRING>(classId), static_cast<GUID>(iid), factory);
}