added a wrapper for RoGetActivationFactory to hook back into winml for creating winml objects.

fixes model load.
This commit is contained in:
Paul McDaniel 2019-11-15 16:47:33 -08:00
parent 8f95b7739a
commit 7f9a7f5abe
2 changed files with 93 additions and 1 deletions

View file

@ -41,6 +41,98 @@ static const char* c_supported_nominal_ranges[] =
"NominalRange_0_255"};
namespace Windows::AI::MachineLearning {
// since this code is now running inside ONNXRUNTIME we need to shortcut
// this a bit when creating winrt objects. This will help.
/* extern "C"
HRESULT __stdcall OS_RoGetActivationFactory(HSTRING classId, GUID const& iid, void** factory) noexcept;
#ifdef _M_IX86
#pragma comment(linker, "/alternatename:_OS_RoGetActivationFactory@12=_RoGetActivationFactory@12")
#else
#pragma comment(linker, "/alternatename:OS_RoGetActivationFactory=RoGetActivationFactory")
#endif
*/
bool starts_with(std::wstring_view value, std::wstring_view match) noexcept
{
return 0 == value.compare(0, match.size(), match);
}
EXTERN_C IMAGE_DOS_HEADER __ImageBase;
std::wstring GetModulePath()
{
std::wstring val;
wchar_t modulePath[MAX_PATH] = { 0 };
GetModuleFileNameW((HINSTANCE)&__ImageBase, modulePath, _countof(modulePath));
wchar_t drive[_MAX_DRIVE];
wchar_t dir[_MAX_DIR];
wchar_t filename[_MAX_FNAME];
wchar_t ext[_MAX_EXT];
_wsplitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT);
val = drive;
val += dir;
return val;
}
extern "C"
int32_t WINRT_CALL WINRT_RoGetActivationFactory(void* classId, winrt::guid const& iid, void** factory) noexcept
{
*factory = nullptr;
HSTRING classId_hstring = (HSTRING)classId;
std::wstring_view name{ WindowsGetStringRawBuffer(classId_hstring, nullptr), WindowsGetStringLen(classId_hstring) };
HMODULE library{ nullptr };
std::wstring winmlDllPath = GetModulePath() + L"Windows.AI.MachineLearning.dll";
if (starts_with(name, L"Windows.AI.MachineLearning."))
{
const wchar_t* libPath = winmlDllPath.c_str();
library = LoadLibraryW(libPath);
}
else
{
return RoGetActivationFactory(classId_hstring, iid, factory);
}
if (!library)
{
return HRESULT_FROM_WIN32(GetLastError());
}
using DllGetActivationFactory = HRESULT __stdcall(HSTRING classId, void** factory);
auto call = reinterpret_cast<DllGetActivationFactory*>(GetProcAddress(library, "DllGetActivationFactory"));
if (!call)
{
HRESULT const hr = HRESULT_FROM_WIN32(GetLastError());
WINRT_VERIFY(FreeLibrary(library));
return hr;
}
winrt::com_ptr<winrt::Windows::Foundation::IActivationFactory> activation_factory;
HRESULT const hr = call(classId_hstring, activation_factory.put_void());
if (FAILED(hr))
{
WINRT_VERIFY(FreeLibrary(library));
return hr;
}
if (winrt::guid(iid) != winrt::guid_of<winrt::Windows::Foundation::IActivationFactory>())
{
return activation_factory->QueryInterface(iid, factory);
}
*factory = activation_factory.detach();
return S_OK;
}
// Forward declare CreateFeatureDescriptor
static winml::ILearningModelFeatureDescriptor
CreateFeatureDescriptor(

View file

@ -327,7 +327,7 @@ public:
auto model_proto_inner = new onnx::ModelProto();
THROW_HR_IF_MSG(
E_INVALIDARG,
!model_proto_inner->ParseFromZeroCopyStream(&stream) == false,
model_proto_inner->ParseFromZeroCopyStream(&stream) == false,
"The stream failed to parse.");
auto model_proto_outer = wil::MakeOrThrow<ModelProto>(model_proto_inner);