onnxruntime/winml/test/scenario/cppwinrt/CustomOperatorProvider.h
Tiago Koji Castro Shibata f40d5d80a0
Add scenario tests (#2457)
* Add scenario tests

* Remove TODO from model license

* Add winml_api test dependency
2019-11-25 17:17:23 -08:00

57 lines
1.6 KiB
C

#pragma once
#include "NoisyReluCpu.h"
#include "ReluCpu.h"
// Test operator provider. It loads windows.ai.machinelearning.dll at runtime,
// creates an IMLOperatorRegistry through the DLL's MLCreateOperatorRegistry
// export, and registers the NoisyRelu schema plus the Relu and NoisyRelu kernels
// into that registry. The registry is handed out via GetRegistry
// (ILearningModelOperatorProviderNative).
struct CustomOperatorProvider :
    winrt::implements<
        CustomOperatorProvider,
        winrt::Windows::AI::MachineLearning::ILearningModelOperatorProvider,
        ILearningModelOperatorProviderNative>
{
    // Handle to windows.ai.machinelearning.dll; freed in the destructor.
    // Remains nullptr if neither supported WINAPI partition branch is compiled.
    HMODULE m_library = nullptr;

    // Registry created by the MLCreateOperatorRegistry export of m_library.
    // Its implementation lives in that module, so this provider's reference
    // must be released before the module handle is freed.
    winrt::com_ptr<IMLOperatorRegistry> m_registry;

    // Loads the WinML DLL, creates the operator registry and registers the
    // custom schemas and kernels. Terminates the process via __fastfail(0) if
    // the DLL cannot be loaded, the export is missing, or registry creation
    // fails.
    CustomOperatorProvider()
    {
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
        m_library = LoadLibraryW(L"windows.ai.machinelearning.dll");
#elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PC_APP)
        m_library = LoadPackagedLibrary(L"windows.ai.machinelearning.dll", 0 /*Reserved*/);
#endif
        // Without the module there is nothing to resolve the export from.
        if (m_library == nullptr)
        {
            __fastfail(0);
        }

        using create_registry_delegate = HRESULT WINAPI (_COM_Outptr_ IMLOperatorRegistry** registry);
        auto create_registry = reinterpret_cast<create_registry_delegate*>(GetProcAddress(m_library, "MLCreateOperatorRegistry"));

        // A missing export would otherwise be called through a null pointer;
        // fail fast with the same code used for a failed registry creation.
        if (create_registry == nullptr || FAILED(create_registry(m_registry.put())))
        {
            __fastfail(0);
        }

        RegisterSchemas();
        RegisterKernels();
    }

    // Releases this provider's registry reference, then the module handle.
    ~CustomOperatorProvider()
    {
        // The destructor body runs before members are destroyed. Drop the
        // registry explicitly so its Release() executes while the implementing
        // DLL is still guaranteed to be loaded.
        m_registry = nullptr;
        if (m_library != nullptr)
        {
            FreeLibrary(m_library);
        }
    }

    // The provider owns a raw module handle; a copy would free it twice.
    CustomOperatorProvider(const CustomOperatorProvider&) = delete;
    CustomOperatorProvider& operator=(const CustomOperatorProvider&) = delete;

    // Registers the operator schemas defined by this test (NoisyRelu).
    void RegisterSchemas()
    {
        NoisyReluOperatorFactory::RegisterNoisyReluSchema(m_registry);
    }

    // Registers the operator kernels provided by this test.
    void RegisterKernels()
    {
        // Replace the Relu operator kernel
        ReluOperatorFactory::RegisterReluKernel(m_registry);
        // Add a new operator kernel for NoisyRelu
        NoisyReluOperatorFactory::RegisterNoisyReluKernel(m_registry);
    }

    // ILearningModelOperatorProviderNative: returns an AddRef'd reference to
    // the registry in *ppOperatorRegistry. Returns E_POINTER if the out
    // parameter is null, S_OK otherwise.
    STDMETHOD(GetRegistry)(IMLOperatorRegistry** ppOperatorRegistry)
    {
        if (ppOperatorRegistry == nullptr)
        {
            return E_POINTER;
        }
        m_registry.copy_to(ppOperatorRegistry);
        return S_OK;
    }
};