mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
67 lines
No EOL
1.8 KiB
C++
67 lines
No EOL
1.8 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#pragma once
|
|
|
|
#include "NoisyReluCpu.h"
|
|
#include "ReluCpu.h"
|
|
#include <winnt.h>
|
|
|
|
struct CustomOperatorProvider :
|
|
winrt::implements<
|
|
CustomOperatorProvider,
|
|
winml::ILearningModelOperatorProvider,
|
|
ILearningModelOperatorProviderNative>
|
|
{
|
|
HMODULE m_library;
|
|
winrt::com_ptr<IMLOperatorRegistry> m_registry;
|
|
|
|
CustomOperatorProvider()
|
|
{
|
|
std::wostringstream dll;
|
|
dll << BINARY_NAME;
|
|
auto winml_dll_name = dll.str();
|
|
|
|
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
|
|
m_library = LoadLibraryExW(winml_dll_name.c_str(), nullptr, 0);
|
|
#elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PC_APP)
|
|
m_library = LoadPackagedLibrary(winml_dll_name.c_str(), 0 /*Reserved*/);
|
|
#endif
|
|
WINML_EXPECT_TRUE(m_library != 0);
|
|
|
|
using create_registry_delegate = HRESULT WINAPI (_COM_Outptr_ IMLOperatorRegistry** registry);
|
|
auto create_registry = reinterpret_cast<create_registry_delegate*>(GetProcAddress(m_library, "MLCreateOperatorRegistry"));
|
|
if (FAILED(create_registry(m_registry.put())))
|
|
{
|
|
__fastfail(0);
|
|
}
|
|
|
|
RegisterSchemas();
|
|
RegisterKernels();
|
|
}
|
|
|
|
~CustomOperatorProvider()
|
|
{
|
|
FreeLibrary(m_library);
|
|
}
|
|
|
|
void RegisterSchemas()
|
|
{
|
|
NoisyReluOperatorFactory::RegisterNoisyReluSchema(m_registry);
|
|
}
|
|
|
|
void RegisterKernels()
|
|
{
|
|
// Replace the Relu operator kernel
|
|
ReluOperatorFactory::RegisterReluKernel(m_registry);
|
|
|
|
// Add a new operator kernel for Relu
|
|
NoisyReluOperatorFactory::RegisterNoisyReluKernel(m_registry);
|
|
}
|
|
|
|
STDMETHOD(GetRegistry)(IMLOperatorRegistry** ppOperatorRegistry)
|
|
{
|
|
m_registry.copy_to(ppOperatorRegistry);
|
|
return S_OK;
|
|
}
|
|
}; |