mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-28 03:20:58 +00:00
added a wrapper for RoGetActivationFactory to hook back into winml for creating winml objects.
fixes model load.
This commit is contained in:
parent
8f95b7739a
commit
7f9a7f5abe
2 changed files with 93 additions and 1 deletions
|
|
@ -41,6 +41,98 @@ static const char* c_supported_nominal_ranges[] =
|
|||
"NominalRange_0_255"};
|
||||
|
||||
namespace Windows::AI::MachineLearning {
|
||||
|
||||
|
||||
// since this code is now running inside ONNXRUNTIME we need to shortcut
|
||||
// this a bit when creating winrt objects. This will help.
|
||||
|
||||
/* extern "C"
|
||||
HRESULT __stdcall OS_RoGetActivationFactory(HSTRING classId, GUID const& iid, void** factory) noexcept;
|
||||
|
||||
#ifdef _M_IX86
|
||||
#pragma comment(linker, "/alternatename:_OS_RoGetActivationFactory@12=_RoGetActivationFactory@12")
|
||||
#else
|
||||
#pragma comment(linker, "/alternatename:OS_RoGetActivationFactory=RoGetActivationFactory")
|
||||
#endif
|
||||
*/
|
||||
|
||||
bool starts_with(std::wstring_view value, std::wstring_view match) noexcept
|
||||
{
|
||||
return 0 == value.compare(0, match.size(), match);
|
||||
}
|
||||
|
||||
EXTERN_C IMAGE_DOS_HEADER __ImageBase;
|
||||
|
||||
std::wstring GetModulePath()
|
||||
{
|
||||
std::wstring val;
|
||||
wchar_t modulePath[MAX_PATH] = { 0 };
|
||||
GetModuleFileNameW((HINSTANCE)&__ImageBase, modulePath, _countof(modulePath));
|
||||
wchar_t drive[_MAX_DRIVE];
|
||||
wchar_t dir[_MAX_DIR];
|
||||
wchar_t filename[_MAX_FNAME];
|
||||
wchar_t ext[_MAX_EXT];
|
||||
_wsplitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT);
|
||||
|
||||
val = drive;
|
||||
val += dir;
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
extern "C"
|
||||
int32_t WINRT_CALL WINRT_RoGetActivationFactory(void* classId, winrt::guid const& iid, void** factory) noexcept
|
||||
{
|
||||
*factory = nullptr;
|
||||
HSTRING classId_hstring = (HSTRING)classId;
|
||||
std::wstring_view name{ WindowsGetStringRawBuffer(classId_hstring, nullptr), WindowsGetStringLen(classId_hstring) };
|
||||
HMODULE library{ nullptr };
|
||||
|
||||
std::wstring winmlDllPath = GetModulePath() + L"Windows.AI.MachineLearning.dll";
|
||||
|
||||
if (starts_with(name, L"Windows.AI.MachineLearning."))
|
||||
{
|
||||
const wchar_t* libPath = winmlDllPath.c_str();
|
||||
library = LoadLibraryW(libPath);
|
||||
}
|
||||
else
|
||||
{
|
||||
return RoGetActivationFactory(classId_hstring, iid, factory);
|
||||
}
|
||||
|
||||
if (!library)
|
||||
{
|
||||
return HRESULT_FROM_WIN32(GetLastError());
|
||||
}
|
||||
|
||||
using DllGetActivationFactory = HRESULT __stdcall(HSTRING classId, void** factory);
|
||||
auto call = reinterpret_cast<DllGetActivationFactory*>(GetProcAddress(library, "DllGetActivationFactory"));
|
||||
|
||||
if (!call)
|
||||
{
|
||||
HRESULT const hr = HRESULT_FROM_WIN32(GetLastError());
|
||||
WINRT_VERIFY(FreeLibrary(library));
|
||||
return hr;
|
||||
}
|
||||
|
||||
winrt::com_ptr<winrt::Windows::Foundation::IActivationFactory> activation_factory;
|
||||
HRESULT const hr = call(classId_hstring, activation_factory.put_void());
|
||||
|
||||
if (FAILED(hr))
|
||||
{
|
||||
WINRT_VERIFY(FreeLibrary(library));
|
||||
return hr;
|
||||
}
|
||||
|
||||
if (winrt::guid(iid) != winrt::guid_of<winrt::Windows::Foundation::IActivationFactory>())
|
||||
{
|
||||
return activation_factory->QueryInterface(iid, factory);
|
||||
}
|
||||
|
||||
*factory = activation_factory.detach();
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
// Forward declare CreateFeatureDescriptor
|
||||
static winml::ILearningModelFeatureDescriptor
|
||||
CreateFeatureDescriptor(
|
||||
|
|
|
|||
|
|
@ -327,7 +327,7 @@ public:
|
|||
auto model_proto_inner = new onnx::ModelProto();
|
||||
THROW_HR_IF_MSG(
|
||||
E_INVALIDARG,
|
||||
!model_proto_inner->ParseFromZeroCopyStream(&stream) == false,
|
||||
model_proto_inner->ParseFromZeroCopyStream(&stream) == false,
|
||||
"The stream failed to parse.");
|
||||
|
||||
auto model_proto_outer = wil::MakeOrThrow<ModelProto>(model_proto_inner);
|
||||
|
|
|
|||
Loading…
Reference in a new issue