onnxruntime/winml/test/scenario/cppwinrt/CustomOperatorProvider.h
Justin Chu 416dc2e84d
Fix clang-format comment indents on Windows for winml/ (#17144)
On Windows, clang-format has a bug when AlignTrailingComments.Kind is
set to `Leave`
(https://clang.llvm.org/docs/ClangFormatStyleOptions.html#aligntrailingcomments),
where it will keep adding indentation to comments after each formatting
runs.

This PR changes to always align comments so we do not hit the bug.

As a consequence of the options change we need to reformat some of the
files. Note that this option is aligned with the rest of the repository.
2023-08-14 23:50:14 -04:00

55 lines
1.7 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "NoisyReluCpu.h"
#include "ReluCpu.h"
#include <winnt.h>
struct CustomOperatorProvider
: winrt::
implements<CustomOperatorProvider, winml::ILearningModelOperatorProvider, ILearningModelOperatorProviderNative> {
HMODULE m_library;
winrt::com_ptr<IMLOperatorRegistry> m_registry;
CustomOperatorProvider() {
std::wostringstream dll;
dll << BINARY_NAME;
auto winml_dll_name = dll.str();
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
m_library = LoadLibraryExW(winml_dll_name.c_str(), nullptr, 0);
#elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PC_APP)
m_library = LoadPackagedLibrary(winml_dll_name.c_str(), 0 /*Reserved*/);
#endif
WINML_EXPECT_TRUE(m_library != 0);
using create_registry_delegate = HRESULT WINAPI(_COM_Outptr_ IMLOperatorRegistry * *registry);
auto create_registry =
reinterpret_cast<create_registry_delegate*>(GetProcAddress(m_library, "MLCreateOperatorRegistry"));
if (FAILED(create_registry(m_registry.put()))) {
__fastfail(0);
}
RegisterSchemas();
RegisterKernels();
}
~CustomOperatorProvider() { FreeLibrary(m_library); }
void RegisterSchemas() { NoisyReluOperatorFactory::RegisterNoisyReluSchema(m_registry); }
void RegisterKernels() {
// Replace the Relu operator kernel
ReluOperatorFactory::RegisterReluKernel(m_registry);
// Add a new operator kernel for Relu
NoisyReluOperatorFactory::RegisterNoisyReluKernel(m_registry);
}
STDMETHOD(GetRegistry)(IMLOperatorRegistry** ppOperatorRegistry) {
m_registry.copy_to(ppOperatorRegistry);
return S_OK;
}
};