mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
60 lines
1.7 KiB
C
60 lines
1.7 KiB
C
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||
|
|
// Licensed under the MIT License.
|
||
|
|
|
||
|
|
#pragma once
|
||
|
|
|
||
|
|
#include "NoisyReluCpu.h"
|
||
|
|
#include "ReluCpu.h"
|
||
|
|
|
||
|
|
struct CustomOperatorProvider :
|
||
|
|
winrt::implements<
|
||
|
|
CustomOperatorProvider,
|
||
|
|
winrt::Windows::AI::MachineLearning::ILearningModelOperatorProvider,
|
||
|
|
ILearningModelOperatorProviderNative>
|
||
|
|
{
|
||
|
|
HMODULE m_library;
|
||
|
|
winrt::com_ptr<IMLOperatorRegistry> m_registry;
|
||
|
|
|
||
|
|
CustomOperatorProvider()
|
||
|
|
{
|
||
|
|
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
|
||
|
|
m_library = LoadLibraryW(L"windows.ai.machinelearning.dll");
|
||
|
|
#elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PC_APP)
|
||
|
|
m_library = LoadPackagedLibrary(L"windows.ai.machinelearning.dll", 0 /*Reserved*/);
|
||
|
|
#endif
|
||
|
|
using create_registry_delegate = HRESULT WINAPI (_COM_Outptr_ IMLOperatorRegistry** registry);
|
||
|
|
auto create_registry = reinterpret_cast<create_registry_delegate*>(GetProcAddress(m_library, "MLCreateOperatorRegistry"));
|
||
|
|
if (FAILED(create_registry(m_registry.put())))
|
||
|
|
{
|
||
|
|
__fastfail(0);
|
||
|
|
}
|
||
|
|
|
||
|
|
RegisterSchemas();
|
||
|
|
RegisterKernels();
|
||
|
|
}
|
||
|
|
|
||
|
|
~CustomOperatorProvider()
|
||
|
|
{
|
||
|
|
FreeLibrary(m_library);
|
||
|
|
}
|
||
|
|
|
||
|
|
void RegisterSchemas()
|
||
|
|
{
|
||
|
|
NoisyReluOperatorFactory::RegisterNoisyReluSchema(m_registry);
|
||
|
|
}
|
||
|
|
|
||
|
|
void RegisterKernels()
|
||
|
|
{
|
||
|
|
// Replace the Relu operator kernel
|
||
|
|
ReluOperatorFactory::RegisterReluKernel(m_registry);
|
||
|
|
|
||
|
|
// Add a new operator kernel for Relu
|
||
|
|
NoisyReluOperatorFactory::RegisterNoisyReluKernel(m_registry);
|
||
|
|
}
|
||
|
|
|
||
|
|
STDMETHOD(GetRegistry)(IMLOperatorRegistry** ppOperatorRegistry)
|
||
|
|
{
|
||
|
|
m_registry.copy_to(ppOperatorRegistry);
|
||
|
|
return S_OK;
|
||
|
|
}
|
||
|
|
};
|