onnxruntime/winml/test/scenario/cppwinrt/CustomOperatorProvider.h
Sheil Kumar 2717c178cc
Fork the WinML APIs into the Microsoft namespace (#3503)
* Migrate winml to Microsoft Namespace (packaging changes are pending)

* add ns_prefix toggle

* fix packaging

* Users/sheilk/add missing raw header (#3484)

* add dualapipartition

* wrong variable for repo root

Co-authored-by: Sheil Kumar <sheilk@microsoft.com>

* remove existence check to force failures

* extra paren

* dualapipartition needs to be referenced from the source

* add microsoft.ai.machinelearning.dll to the output dir

* rename the idl file so that assembly info is correctly added into the winmd

* fix namespaces

* update namespaces

* default to microsoft, and add namespace override as build argument

* update CMakeSettings.json as well

* remove from cmakelists.txt

Co-authored-by: Sheil Kumar <sheilk@microsoft.com>
Co-authored-by: Changming Sun <chasun@microsoft.com>
2020-04-17 06:18:54 -07:00

64 lines
No EOL
1.8 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "NoisyReluCpu.h"
#include "ReluCpu.h"
// Test-only operator provider. On construction it obtains a fresh
// IMLOperatorRegistry from the binary named by BINARY_NAME (via its exported
// MLCreateOperatorRegistry function), registers the NoisyRelu schema, and
// registers kernels for Relu (replacement) and NoisyRelu (new operator).
// The registry is handed out to callers through GetRegistry().
struct CustomOperatorProvider :
    winrt::implements<
        CustomOperatorProvider,
        winml::ILearningModelOperatorProvider,
        ILearningModelOperatorProviderNative>
{
    // Handle to the module exporting MLCreateOperatorRegistry. Starts as
    // nullptr so the destructor stays safe even if no load path was compiled
    // in for the current WINAPI partition.
    HMODULE m_library = nullptr;

    // Registry populated by RegisterSchemas()/RegisterKernels().
    winrt::com_ptr<IMLOperatorRegistry> m_registry;

    // Loads the module, creates the registry and registers the custom
    // schemas/kernels. Any failure to load the module, find the export or
    // create the registry terminates the process with __fastfail.
    CustomOperatorProvider()
    {
        std::wostringstream dll;
        dll << BINARY_NAME;
        auto winml_dll_name = dll.str();

#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
        m_library = LoadLibraryW(winml_dll_name.c_str());
#elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PC_APP)
        m_library = LoadPackagedLibrary(winml_dll_name.c_str(), 0 /*Reserved*/);
#endif
        // Without a module handle GetProcAddress below cannot succeed.
        if (m_library == nullptr)
        {
            __fastfail(0);
        }

        using create_registry_delegate = HRESULT WINAPI (_COM_Outptr_ IMLOperatorRegistry** registry);
        auto create_registry = reinterpret_cast<create_registry_delegate*>(GetProcAddress(m_library, "MLCreateOperatorRegistry"));

        // Guard against a missing export before calling through the pointer.
        if (create_registry == nullptr || FAILED(create_registry(m_registry.put())))
        {
            __fastfail(0);
        }

        RegisterSchemas();
        RegisterKernels();
    }

    // The struct owns a raw module handle; copying would lead to a double
    // FreeLibrary, so copying is disallowed.
    CustomOperatorProvider(const CustomOperatorProvider&) = delete;
    CustomOperatorProvider& operator=(const CustomOperatorProvider&) = delete;

    ~CustomOperatorProvider()
    {
        // Drop our registry reference before releasing the module that
        // implements it. Otherwise the member destructor would run Release()
        // after FreeLibrary.
        m_registry = nullptr;
        if (m_library != nullptr)
        {
            FreeLibrary(m_library);
        }
    }

    // Registers custom operator schemas (currently only NoisyRelu).
    void RegisterSchemas()
    {
        NoisyReluOperatorFactory::RegisterNoisyReluSchema(m_registry);
    }

    // Registers custom operator kernels.
    void RegisterKernels()
    {
        // Replace the Relu operator kernel
        ReluOperatorFactory::RegisterReluKernel(m_registry);
        // Add a new operator kernel for NoisyRelu
        NoisyReluOperatorFactory::RegisterNoisyReluKernel(m_registry);
    }

    // ILearningModelOperatorProviderNative::GetRegistry.
    // Returns an AddRef'd pointer to the populated registry in
    // *ppOperatorRegistry, or E_POINTER if the out-pointer is null.
    STDMETHOD(GetRegistry)(IMLOperatorRegistry** ppOperatorRegistry)
    {
        if (ppOperatorRegistry == nullptr)
        {
            return E_POINTER;
        }
        m_registry.copy_to(ppOperatorRegistry);
        return S_OK;
    }
};