mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
# Motivation Currently, ORT minimal builds use kernel def hashes to map from nodes to kernels to execute when loading the model. As the kernel def hashes must be known ahead of time, this works for statically registered kernels. This works well for the CPU EP. For this approach to work, the kernel def hashes must also be known at ORT format model conversion time, which means the EP with statically registered kernels must also be enabled then. This is not an issue for the always-available CPU EP. However, we do not want to require that any EP which statically registers kernels is always available too. Consequently, we explore another approach to match nodes to kernels that does not rely on kernel def hashes. An added benefit of this is the possibility of moving away from kernel def hashes completely, which would eliminate the maintenance burden of keeping the hashes stable. # Approach In a full build, ORT uses some information from the ONNX op schema to match a node to a kernel. We want to avoid including the ONNX op schema in a minimal build to reduce binary size. Essentially, we take the necessary information from the ONNX op schema and make it available in a minimal build. We decouple the ONNX op schema from the kernel matching logic. The kernel matching logic instead relies on per-op information which can either be obtained from the ONNX op schema or another source. This per-op information must be available in a minimal build when there are no ONNX op schemas. We put it in the ORT format model. Existing uses of kernel def hashes to look up kernels are replaced with the updated kernel matching logic. We no longer store kernel def hashes in the ORT format model’s session state and runtime optimization representations. We no longer keep the logic to generate and ensure stability of kernel def hashes.
694 lines
No EOL
28 KiB
C++
694 lines
No EOL
28 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "testPch.h"
|
|
#include <wil/result.h>
|
|
#include <D3d11_4.h>
|
|
#include <dxgi1_6.h>
|
|
#include "filehelpers.h"
|
|
#include <fstream>
|
|
#include <MemoryBuffer.h>
|
|
#include <gsl/gsl>
|
|
#include "CustomOperatorProvider.h"
|
|
#include "CustomOps.h"
|
|
|
|
// For custom operator and shape inferencing support
|
|
#include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h"
|
|
#include "core/providers/dml/DmlExecutionProvider/src/ErrorHandling.h"
|
|
#include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h"
|
|
#include "core/providers/dml/OperatorAuthorHelper/OperatorHelper.h"
|
|
#include "core/providers/dml/OperatorAuthorHelper/OperatorVersions.h"
|
|
#include "core/graph/constants.h"
|
|
#include "CustomNullOp.h"
|
|
#include <wil/wrl.h>
|
|
|
|
using namespace winml;
|
|
using namespace wfc;
|
|
using namespace wm;
|
|
using namespace wgi;
|
|
using namespace ws;
|
|
using namespace wss;
|
|
|
|
static void CustomOpsScenarioTestsClassSetup()
|
|
{
|
|
winrt::init_apartment();
|
|
#ifdef BUILD_INBOX
|
|
winrt_activation_handler = WINRT_RoGetActivationFactory;
|
|
#endif
|
|
}
|
|
|
|
// Tests that the execution provider correctly fuses operators together when custom ops are involved.
|
|
static void CustomOperatorFusion() {
|
|
constexpr const wchar_t* c_modelFilename = L"squeezenet_tensor_input.onnx";
|
|
|
|
// This particular model has 25 Conv ops and 25 Relu ops, all of which are eligible for fusion so we expect them
|
|
// all to be fused (removing them from the graph) and replaced with the appropriate fused op instead. The same
|
|
// goes for the single Gemm+Sigmoid in the model too.
|
|
constexpr const uint32_t c_expectedConvOps = 0;
|
|
constexpr const uint32_t c_expectedReluOps = 0;
|
|
constexpr const uint32_t c_expectedFusedConvOps = 25;
|
|
constexpr const uint32_t c_expectedGemmOps = 0;
|
|
constexpr const uint32_t c_expectedSigmoidOps = 0;
|
|
constexpr const uint32_t c_expectedFusedGemmOps = 1;
|
|
|
|
// These ops are also part of the model but shouldn't be fused
|
|
constexpr const uint32_t c_expectedBatchNormOps = 1;
|
|
constexpr const uint32_t c_expectedMaxPoolOps = 3;
|
|
constexpr const uint32_t c_expectedConcatOps = 8;
|
|
|
|
struct CallbackOperatorProvider : winrt::implements<
|
|
CallbackOperatorProvider,
|
|
winml::ILearningModelOperatorProvider,
|
|
ILearningModelOperatorProviderNative> {
|
|
struct CallCounts {
|
|
std::atomic<uint32_t> conv = 0;
|
|
std::atomic<uint32_t> relu = 0;
|
|
std::atomic<uint32_t> fusedConv = 0;
|
|
std::atomic<uint32_t> gemm = 0;
|
|
std::atomic<uint32_t> sigmoid = 0;
|
|
std::atomic<uint32_t> fusedGemm = 0;
|
|
std::atomic<uint32_t> batchNorm = 0;
|
|
std::atomic<uint32_t> maxPool = 0;
|
|
std::atomic<uint32_t> concat = 0;
|
|
};
|
|
|
|
const CallCounts& GetCallCounts() {
|
|
return m_callCounts;
|
|
}
|
|
|
|
CallbackOperatorProvider() {
|
|
using namespace OperatorHelper;
|
|
|
|
std::wostringstream dll;
|
|
dll << BINARY_NAME;
|
|
auto winml_dll_name = dll.str();
|
|
|
|
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
|
|
auto m_library = LoadLibraryExW(winml_dll_name.c_str(), nullptr, 0);
|
|
#elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PC_APP)
|
|
auto m_library = LoadPackagedLibrary(winml_dll_name.c_str(), 0 /*Reserved*/);
|
|
#endif
|
|
using create_registry_delegate = HRESULT WINAPI(_COM_Outptr_ IMLOperatorRegistry * *registry);
|
|
auto create_registry = reinterpret_cast<create_registry_delegate*>(GetProcAddress(m_library, "MLCreateOperatorRegistry"));
|
|
WINML_EXPECT_HRESULT_SUCCEEDED(create_registry(m_registry.put()));
|
|
|
|
#pragma push_macro("REGISTER_KERNEL")
|
|
#define REGISTER_KERNEL(_name, _domain, _opSet, _shapeInferrer, _callCount) \
|
|
NullOperatorFactory::RegisterKernel( \
|
|
#_name, \
|
|
(_domain), \
|
|
_opSet::sc_sinceVer_##_name, \
|
|
m_registry, \
|
|
winrt::make<NullShapeInferrer<_shapeInferrer>>(), \
|
|
(_callCount));
|
|
|
|
REGISTER_KERNEL(Conv, onnxruntime::kOnnxDomain, OnnxOperatorSet7, ConvHelper, &m_callCounts.conv);
|
|
REGISTER_KERNEL(Relu, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.relu);
|
|
REGISTER_KERNEL(DmlFusedConv, onnxruntime::kMSDmlDomain, MsftOperatorSet1, ConvHelper, &m_callCounts.fusedConv);
|
|
|
|
REGISTER_KERNEL(Gemm, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GemmHelper, &m_callCounts.gemm);
|
|
REGISTER_KERNEL(Sigmoid, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.sigmoid);
|
|
REGISTER_KERNEL(DmlFusedGemm, onnxruntime::kMSDmlDomain, MsftOperatorSet1, GemmHelper, &m_callCounts.fusedGemm);
|
|
|
|
REGISTER_KERNEL(BatchNormalization, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.batchNorm);
|
|
REGISTER_KERNEL(MaxPool, onnxruntime::kOnnxDomain, OnnxOperatorSet7, PoolingHelper, &m_callCounts.maxPool);
|
|
REGISTER_KERNEL(Concat, onnxruntime::kOnnxDomain, OnnxOperatorSet7, ConcatHelper, &m_callCounts.concat);
|
|
|
|
#pragma pop_macro("REGISTER_KERNEL")
|
|
}
|
|
|
|
STDMETHOD(GetRegistry)
|
|
(IMLOperatorRegistry** ppOperatorRegistry) {
|
|
if (ppOperatorRegistry == nullptr) {
|
|
return E_POINTER;
|
|
}
|
|
|
|
m_registry.copy_to(ppOperatorRegistry);
|
|
return S_OK;
|
|
}
|
|
|
|
private:
|
|
winrt::com_ptr<IMLOperatorRegistry> m_registry;
|
|
CallCounts m_callCounts;
|
|
};
|
|
|
|
auto customOperatorProvider = winrt::make<CallbackOperatorProvider>();
|
|
auto provider = customOperatorProvider.as<ILearningModelOperatorProvider>();
|
|
|
|
LearningModelDevice device = nullptr;
|
|
WINML_EXPECT_NO_THROW(device = LearningModelDevice(LearningModelDeviceKind::DirectX));
|
|
std::wstring fullPath = FileHelpers::GetModulePath() + c_modelFilename;
|
|
auto model = LearningModel::LoadFromFilePath(fullPath, provider);
|
|
|
|
auto featureValue = FileHelpers::LoadImageFeatureValue(L"227x227.png");
|
|
|
|
LearningModelSession session = nullptr;
|
|
WINML_EXPECT_NO_THROW(session = LearningModelSession(model, device));
|
|
LearningModelBinding modelBinding(session);
|
|
|
|
modelBinding.Bind(L"data", featureValue);
|
|
auto result = session.Evaluate(modelBinding, L"");
|
|
|
|
const auto& callCounts = customOperatorProvider.as<CallbackOperatorProvider>()->GetCallCounts();
|
|
|
|
// Verify that the correct number of each operator was seen (i.e. that none were dropped / incorrectly fused)
|
|
WINML_EXPECT_EQUAL(c_expectedConvOps, callCounts.conv);
|
|
WINML_EXPECT_EQUAL(c_expectedReluOps, callCounts.relu);
|
|
WINML_EXPECT_EQUAL(c_expectedFusedConvOps, callCounts.fusedConv);
|
|
WINML_EXPECT_EQUAL(c_expectedGemmOps, callCounts.gemm);
|
|
WINML_EXPECT_EQUAL(c_expectedSigmoidOps, callCounts.sigmoid);
|
|
WINML_EXPECT_EQUAL(c_expectedFusedGemmOps, callCounts.fusedGemm);
|
|
WINML_EXPECT_EQUAL(c_expectedBatchNormOps, callCounts.batchNorm);
|
|
WINML_EXPECT_EQUAL(c_expectedMaxPoolOps, callCounts.maxPool);
|
|
WINML_EXPECT_EQUAL(c_expectedConcatOps, callCounts.concat);
|
|
}
|
|
|
|
struct LocalCustomOperatorProvider : winrt::implements<
|
|
LocalCustomOperatorProvider,
|
|
winml::ILearningModelOperatorProvider,
|
|
ILearningModelOperatorProviderNative> {
|
|
LocalCustomOperatorProvider() {
|
|
|
|
std::wostringstream dll;
|
|
dll << BINARY_NAME;
|
|
auto winml_dll_name = dll.str();
|
|
|
|
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
|
|
auto m_library = LoadLibraryExW(winml_dll_name.c_str(), nullptr, 0);
|
|
#elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PC_APP)
|
|
auto m_library = LoadPackagedLibrary(winml_dll_name.c_str(), 0 /*Reserved*/);
|
|
#endif
|
|
using create_registry_delegate = HRESULT WINAPI(_COM_Outptr_ IMLOperatorRegistry * *registry);
|
|
auto create_registry = reinterpret_cast<create_registry_delegate*>(GetProcAddress(m_library, "MLCreateOperatorRegistry"));
|
|
WINML_EXPECT_HRESULT_SUCCEEDED(create_registry(m_registry.put()));
|
|
}
|
|
|
|
STDMETHOD(GetRegistry)
|
|
(IMLOperatorRegistry** ppOperatorRegistry) {
|
|
if (ppOperatorRegistry == nullptr) {
|
|
return E_POINTER;
|
|
}
|
|
|
|
m_registry.copy_to(ppOperatorRegistry);
|
|
return S_OK;
|
|
}
|
|
|
|
IMLOperatorRegistry* GetRegistry() {
|
|
return m_registry.get();
|
|
}
|
|
|
|
protected:
|
|
winrt::com_ptr<IMLOperatorRegistry> m_registry;
|
|
};
|
|
|
|
// Checks test attributes set on ABI kernels can be queried with correct values
|
|
void VerifyTestAttributes(const MLOperatorAttributes& attrs) {
|
|
std::string strAttr = attrs.GetAttribute("DefaultedNonRequiredString");
|
|
WINML_EXPECT_EQUAL(strAttr, "1");
|
|
|
|
std::vector<std::string> strArrayAttr = attrs.GetAttributeVector("DefaultedNonRequiredStringArray");
|
|
std::vector<std::string> expected = std::vector<std::string>({"1", "2"});
|
|
for (size_t i = 0; i < expected.size(); ++i) {
|
|
WINML_EXPECT_EQUAL(strArrayAttr[i], expected[i]);
|
|
}
|
|
|
|
WINML_EXPECT_EQUAL(1, attrs.GetAttribute<int64_t>("DefaultedNonRequiredInt"));
|
|
WINML_EXPECT_EQUAL(1.0f, attrs.GetAttribute<float>("DefaultedNonRequiredFloat"));
|
|
|
|
WINML_EXPECT_EQUAL(std::vector<int64_t>({1, 2}), attrs.GetAttributeVector<int64_t>("DefaultedNonRequiredIntArray"));
|
|
WINML_EXPECT_EQUAL(std::vector<float>({1.0f, 2.0f}), attrs.GetAttributeVector<float>("DefaultedNonRequiredFloatArray"));
|
|
}
|
|
|
|
// Foo kernel which is doing Add and optionally truncates its output
|
|
template <typename T, bool VerifyAttributes = false, bool Truncate = false>
|
|
class FooKernel {
|
|
public:
|
|
FooKernel(const MLOperatorKernelCreationContext& info) {
|
|
if (VerifyAttributes) {
|
|
VerifyTestAttributes(info);
|
|
}
|
|
|
|
VerifyShapeInfo(info);
|
|
}
|
|
|
|
void VerifyShapeInfo(const MLOperatorKernelCreationContext& info) {
|
|
if (!Truncate) {
|
|
winrt::com_ptr<IMLOperatorTensorShapeDescription> shapeInfo;
|
|
WINML_EXPECT_EQUAL(info.GetInterface()->HasTensorShapeDescription(), false);
|
|
WINML_EXPECT_HRESULT_FAILED(info.GetInterface()->GetTensorShapeDescription(shapeInfo.put()));
|
|
} else {
|
|
winrt::com_ptr<IMLOperatorTensorShapeDescription> shapeInfo;
|
|
WINML_EXPECT_EQUAL(info.GetInterface()->HasTensorShapeDescription(), true);
|
|
WINML_EXPECT_EQUAL(info.GetInterface()->GetTensorShapeDescription(shapeInfo.put()), S_OK);
|
|
}
|
|
}
|
|
|
|
void Compute(const MLOperatorKernelContext& context) const {
|
|
const auto X = context.GetInputTensor(0);
|
|
const auto W = context.GetInputTensor(1);
|
|
|
|
auto xData = X.GetData<T>();
|
|
auto wData = W.GetData<T>();
|
|
|
|
auto shape = X.GetShape();
|
|
|
|
// This is used to test shape inference
|
|
if (Truncate) {
|
|
shape[0] -= 1;
|
|
}
|
|
|
|
if (!Truncate) {
|
|
winrt::com_ptr<IMLOperatorTensor> tensor;
|
|
WINML_EXPECT_HRESULT_FAILED(context.GetInterface()->GetOutputTensor(0, tensor.put()));
|
|
} else {
|
|
MLOperatorTensor tensor = context.GetOutputTensor(0);
|
|
}
|
|
|
|
auto Y = context.GetOutputTensor(0, shape);
|
|
auto yData = Y.GetData<T>();
|
|
|
|
size_t size = 1;
|
|
for (size_t i = 0; i < shape.size(); i++) {
|
|
size *= shape[i];
|
|
}
|
|
|
|
for (size_t i = 0; i < size; i++) {
|
|
yData[i] = xData[i] + wData[i];
|
|
}
|
|
}
|
|
};
|
|
|
|
template <bool VerifyTestAttributes = false>
|
|
void CALLBACK CreateABIFooKernel(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel) {
|
|
HRESULT hr = MLOperatorKernel<FooKernel<float, VerifyTestAttributes>>::CreateInstance(*kernelInfo, opKernel);
|
|
THROW_IF_FAILED(hr);
|
|
}
|
|
|
|
void CALLBACK CreateTruncatedABIFooKernel(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel) {
|
|
HRESULT hr = MLOperatorKernel<FooKernel<float, true, true>>::CreateInstance(*kernelInfo, opKernel);
|
|
THROW_IF_FAILED(hr);
|
|
}
|
|
|
|
// Test using a foo kernel which is doing Add, but register it as "Mul".
|
|
static void CustomKernelWithBuiltInSchema() {
|
|
// Create the registry
|
|
auto operatorProvider = winrt::make<LocalCustomOperatorProvider>();
|
|
IMLOperatorRegistry* registry = operatorProvider.as<LocalCustomOperatorProvider>()->GetRegistry();
|
|
|
|
// Register the kernel
|
|
MLOperatorEdgeDescription floatTensorType =
|
|
{
|
|
MLOperatorEdgeType::Tensor,
|
|
static_cast<uint64_t>(MLOperatorTensorDataType::Float)};
|
|
|
|
MLOperatorEdgeTypeConstrant constraint = {"T", &floatTensorType, 1};
|
|
|
|
MLOperatorKernelDescription kernelDesc =
|
|
{
|
|
"",
|
|
"Mul",
|
|
7,
|
|
MLOperatorExecutionType::Cpu,
|
|
&constraint,
|
|
1,
|
|
nullptr,
|
|
0,
|
|
MLOperatorKernelOptions::AllowDynamicInputShapes};
|
|
|
|
Microsoft::WRL::ComPtr<MLOperatorKernelFactory> factory = wil::MakeOrThrow<MLOperatorKernelFactory>(CreateABIFooKernel<false>);
|
|
WINML_EXPECT_HRESULT_SUCCEEDED(registry->RegisterOperatorKernel(&kernelDesc, factory.Get(), nullptr));
|
|
|
|
// Prepare inputs
|
|
std::vector<int64_t> dimsX = {3, 2};
|
|
std::vector<float> valuesX = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
|
|
|
// Prepare expected inputs and outputs
|
|
std::vector<int64_t> expectedDimsY = {3, 2};
|
|
|
|
// The expected value should be Add's result.
|
|
std::vector<float> expectedValuesY = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};
|
|
|
|
// Create the model and sessions
|
|
std::wstring fullPath = FileHelpers::GetModulePath() + L"mul.onnx";
|
|
LearningModel model = LearningModel::LoadFromFilePath(fullPath, operatorProvider);
|
|
|
|
LearningModelSession session(model);
|
|
LearningModelBinding bindings(session);
|
|
|
|
// Bind inputs and outputs
|
|
TensorFloat inputTensor = TensorFloat::CreateFromArray(dimsX, winrt::array_view<const float>(std::move(valuesX)));
|
|
bindings.Bind(winrt::hstring(L"X"), inputTensor);
|
|
|
|
auto outputValue = TensorFloat::Create();
|
|
WINML_EXPECT_NO_THROW(bindings.Bind(L"Y", outputValue));
|
|
|
|
// Evaluate the model
|
|
winrt::hstring correlationId;
|
|
WINML_EXPECT_NO_THROW(session.Evaluate(bindings, correlationId));
|
|
|
|
// Check the result shape
|
|
WINML_EXPECT_EQUAL(expectedDimsY.size(), outputValue.Shape().Size());
|
|
for (uint32_t j = 0; j < outputValue.Shape().Size(); j++) {
|
|
WINML_EXPECT_EQUAL(expectedDimsY.at(j), outputValue.Shape().GetAt(j));
|
|
}
|
|
|
|
// Check the results
|
|
auto buffer = outputValue.GetAsVectorView();
|
|
WINML_EXPECT_TRUE(buffer != nullptr);
|
|
WINML_EXPECT_TRUE(std::equal(expectedValuesY.cbegin(), expectedValuesY.cend(), begin(buffer)));
|
|
|
|
// Release the model before operatorProvider goes out of scope
|
|
model = nullptr;
|
|
}
|
|
|
|
// Similar to MLOperatorShapeInferrer, but using an std::function
|
|
class MLOperatorShapeInferrerFromFunc : public Microsoft::WRL::RuntimeClass<
|
|
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IMLOperatorShapeInferrer> {
|
|
public:
|
|
MLOperatorShapeInferrerFromFunc(std::function<void(IMLOperatorShapeInferenceContext*)> shapeInferenceFn) : m_func(shapeInferenceFn) {}
|
|
|
|
HRESULT STDMETHODCALLTYPE InferOutputShapes(IMLOperatorShapeInferenceContext* context) noexcept override try {
|
|
m_func(context);
|
|
return S_OK;
|
|
}
|
|
CATCH_RETURN();
|
|
|
|
private:
|
|
std::function<void(IMLOperatorShapeInferenceContext*)> m_func;
|
|
};
|
|
|
|
// Test using a custom kernel and schema, while verifying attribute defaults, type mapping, and inference methods
|
|
static void CustomKernelWithCustomSchema() {
|
|
// Test cases
|
|
struct
|
|
{
|
|
// Whether the Foo kernel should truncate its output
|
|
bool truncateOutput;
|
|
|
|
// Whether a type label is used in the schema, versus a type description
|
|
bool useTypeLabel;
|
|
|
|
// Whether the schema provides a type inference function, and uses an output type
|
|
// of Int32 instead of Float32
|
|
bool useTypeInference;
|
|
|
|
// Whether a shape inference method is provided in the schema
|
|
bool useShapeInferenceInSchema;
|
|
|
|
// Whether a shape inference method is provided in the kernel
|
|
bool useShapeInferenceInKernel;
|
|
|
|
// Whether attribute defaults are provided in the schema, instead of the kernel
|
|
bool attributeDefaultsInSchema;
|
|
} testCases[] =
|
|
{
|
|
{false, true, false, false, false, false},
|
|
{false, false, false, false, false, false},
|
|
{false, true, true, false, false, true},
|
|
{true, false, false, false, true, false},
|
|
{true, true, true, true, true, true},
|
|
};
|
|
|
|
for (size_t caseIndex = 0; caseIndex < std::size(testCases); ++caseIndex) {
|
|
// Create the registry
|
|
auto operatorProvider = winrt::make<LocalCustomOperatorProvider>();
|
|
IMLOperatorRegistry* registry = operatorProvider.as<LocalCustomOperatorProvider>()->GetRegistry();
|
|
|
|
// Create input and output parameters
|
|
MLOperatorSchemaEdgeDescription inputParam = {};
|
|
inputParam.options = MLOperatorParameterOptions::Single;
|
|
|
|
if (!testCases[caseIndex].useTypeLabel) {
|
|
assert(!testCases[caseIndex].useTypeInference);
|
|
|
|
MLOperatorEdgeDescription edgeDesc = {};
|
|
edgeDesc.edgeType = MLOperatorEdgeType::Tensor;
|
|
edgeDesc.tensorDataType = MLOperatorTensorDataType::Float;
|
|
|
|
inputParam.typeFormat = MLOperatorSchemaEdgeTypeFormat::EdgeDescription;
|
|
inputParam.edgeDescription = edgeDesc;
|
|
} else {
|
|
inputParam.typeFormat = MLOperatorSchemaEdgeTypeFormat::Label;
|
|
inputParam.typeLabel = "T1";
|
|
}
|
|
|
|
MLOperatorSchemaEdgeDescription outputParam = inputParam;
|
|
|
|
// Type inference should set this to tensor(float) even though T2 is not matched
|
|
// on an input label
|
|
if (testCases[caseIndex].useTypeInference) {
|
|
if (inputParam.typeFormat == MLOperatorSchemaEdgeTypeFormat::Label) {
|
|
outputParam.typeLabel = "T2";
|
|
} else {
|
|
outputParam.edgeDescription.tensorDataType = MLOperatorTensorDataType::Int32;
|
|
}
|
|
}
|
|
|
|
MLOperatorSchemaEdgeDescription inputs[] = {inputParam, inputParam};
|
|
|
|
MLOperatorEdgeDescription edgeTypes[6] =
|
|
{
|
|
{MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::UInt32)},
|
|
{MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::UInt64)},
|
|
{MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Int32)},
|
|
{MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Int64)},
|
|
{MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Float)},
|
|
{MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Double)}};
|
|
|
|
// Type constraints. Only the first is used unless type inference is provided and
|
|
// the kernel emits a different output type as "T2"
|
|
MLOperatorEdgeTypeConstrant constraints[] =
|
|
{
|
|
{"T1", edgeTypes, static_cast<uint32_t>(std::size(edgeTypes))},
|
|
{"T2", edgeTypes, static_cast<uint32_t>(std::size(edgeTypes))}};
|
|
|
|
// Test attributes
|
|
MLOperatorAttribute attributes[] =
|
|
{
|
|
{"DefaultedNonRequiredInt", MLOperatorAttributeType::Int, false},
|
|
{"DefaultedNonRequiredFloat", MLOperatorAttributeType::Float, false},
|
|
{"DefaultedNonRequiredString", MLOperatorAttributeType::String, false},
|
|
{"DefaultedNonRequiredIntArray", MLOperatorAttributeType::IntArray, false},
|
|
{"DefaultedNonRequiredFloatArray", MLOperatorAttributeType::FloatArray, false},
|
|
{"DefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, false},
|
|
|
|
{"NonDefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, false},
|
|
};
|
|
|
|
// Defaults. These are queried back during kernel creation, type and shape inference
|
|
// and tested against the same values
|
|
MLOperatorAttributeNameValue defaultAttributes[] =
|
|
{
|
|
{"DefaultedNonRequiredInt", MLOperatorAttributeType::Int, 1},
|
|
{"DefaultedNonRequiredFloat", MLOperatorAttributeType::Float, 1},
|
|
{"DefaultedNonRequiredString", MLOperatorAttributeType::String, 1},
|
|
{"DefaultedNonRequiredIntArray", MLOperatorAttributeType::IntArray, 2},
|
|
{"DefaultedNonRequiredFloatArray", MLOperatorAttributeType::FloatArray, 2},
|
|
{"DefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, 2},
|
|
};
|
|
|
|
int64_t defaultInts[] = {1, 2};
|
|
float defaultFloats[] = {1.0f, 2.0f};
|
|
const char* defaultStrings[] = {"1", "2"};
|
|
defaultAttributes[0].ints = defaultInts;
|
|
defaultAttributes[1].floats = defaultFloats;
|
|
defaultAttributes[2].strings = defaultStrings;
|
|
defaultAttributes[3].ints = defaultInts;
|
|
defaultAttributes[4].floats = defaultFloats;
|
|
defaultAttributes[5].strings = defaultStrings;
|
|
|
|
// Schema definition
|
|
MLOperatorSchemaDescription schemaDesc = {};
|
|
schemaDesc.name = "Foo";
|
|
schemaDesc.operatorSetVersionAtLastChange = 7;
|
|
schemaDesc.inputs = inputs;
|
|
schemaDesc.inputCount = 2;
|
|
schemaDesc.outputs = &outputParam;
|
|
schemaDesc.outputCount = 1;
|
|
schemaDesc.typeConstraints = constraints;
|
|
schemaDesc.typeConstraintCount = testCases[caseIndex].useTypeLabel ? 2 : 0;
|
|
schemaDesc.attributes = attributes;
|
|
schemaDesc.attributeCount = static_cast<uint32_t>(std::size(attributes));
|
|
|
|
if (testCases[caseIndex].attributeDefaultsInSchema) {
|
|
schemaDesc.defaultAttributes = defaultAttributes;
|
|
schemaDesc.defaultAttributeCount = static_cast<uint32_t>(std::size(defaultAttributes));
|
|
}
|
|
|
|
Microsoft::WRL::ComPtr<MLOperatorTypeInferrer> typeInferrer;
|
|
Microsoft::WRL::ComPtr<MLOperatorShapeInferrerFromFunc> shapeInferrer;
|
|
|
|
// Type inference function
|
|
if (testCases[caseIndex].useTypeInference) {
|
|
typeInferrer = wil::MakeOrThrow<MLOperatorTypeInferrer>([](IMLOperatorTypeInferenceContext* ctx) -> void {
|
|
VerifyTestAttributes(MLOperatorTypeInferenceContext(ctx));
|
|
|
|
MLOperatorEdgeDescription edgeDesc = {};
|
|
edgeDesc.edgeType = MLOperatorEdgeType::Tensor;
|
|
edgeDesc.tensorDataType = MLOperatorTensorDataType::Float;
|
|
|
|
MLOperatorTypeInferenceContext(ctx).SetOutputEdgeDescription(0, &edgeDesc);
|
|
});
|
|
}
|
|
|
|
// Store the shape inference context with a reference following the call to InferOutputShapes.
|
|
// This will be called after loading the model as an isolated test for how ABI context objects
|
|
// are "closed."
|
|
Microsoft::WRL::ComPtr<IMLOperatorShapeInferenceContext> shapeInferenceContext;
|
|
|
|
// Shape inference is tested by truncating the output size
|
|
bool truncateOutput = testCases[caseIndex].truncateOutput;
|
|
if (truncateOutput) {
|
|
shapeInferrer = wil::MakeOrThrow<MLOperatorShapeInferrerFromFunc>([&shapeInferenceContext](IMLOperatorShapeInferenceContext* ctx) -> void {
|
|
VerifyTestAttributes(MLShapeInferenceContext(ctx));
|
|
MLShapeInferenceContext(ctx).SetOutputTensorShape(0, {2, 2});
|
|
shapeInferenceContext = ctx;
|
|
});
|
|
}
|
|
|
|
// Register the schema
|
|
MLOperatorSetId opsetId = {"", 7};
|
|
MLOperatorSchemaDescription* opSchemaDescs = &schemaDesc;
|
|
WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorSetSchema(
|
|
&opsetId,
|
|
1,
|
|
&opSchemaDescs,
|
|
1,
|
|
typeInferrer.Get(),
|
|
testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr));
|
|
|
|
{
|
|
// Register a future version of the schema in the same domain, while setting its
|
|
// input count to zero to ensure it is not being used.
|
|
auto futureSchemaDesc = schemaDesc;
|
|
futureSchemaDesc.inputCount = 0;
|
|
|
|
MLOperatorSetId id = {"", 9};
|
|
MLOperatorSchemaDescription* schemaDescs = &futureSchemaDesc;
|
|
WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorSetSchema(
|
|
&id,
|
|
7,
|
|
&schemaDescs,
|
|
1,
|
|
typeInferrer.Get(),
|
|
testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr));
|
|
}
|
|
{
|
|
// Register in another (unused) domain to the custom registry
|
|
auto otherSchemaDesc = schemaDesc;
|
|
otherSchemaDesc.inputCount = 0;
|
|
|
|
MLOperatorSetId id = {"otherDomain", 7};
|
|
MLOperatorSchemaDescription* schemaDescs = &otherSchemaDesc;
|
|
WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorSetSchema(
|
|
&id,
|
|
1,
|
|
&schemaDescs,
|
|
1,
|
|
typeInferrer.Get(),
|
|
testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr));
|
|
}
|
|
// Register the Foo kernel
|
|
MLOperatorEdgeDescription floatTensorEdgeDesc = {};
|
|
floatTensorEdgeDesc.edgeType = MLOperatorEdgeType::Tensor;
|
|
floatTensorEdgeDesc.tensorDataType = MLOperatorTensorDataType::Float;
|
|
|
|
MLOperatorEdgeTypeConstrant kernelConstraint = {"T1", &floatTensorEdgeDesc, 1};
|
|
|
|
MLOperatorKernelDescription kernelDesc =
|
|
{
|
|
"",
|
|
"Foo",
|
|
7,
|
|
MLOperatorExecutionType::Cpu,
|
|
&kernelConstraint,
|
|
testCases[caseIndex].useTypeLabel ? 1u : 0u};
|
|
|
|
if (!testCases[caseIndex].attributeDefaultsInSchema) {
|
|
kernelDesc.defaultAttributes = defaultAttributes;
|
|
kernelDesc.defaultAttributeCount = static_cast<uint32_t>(std::size(defaultAttributes));
|
|
}
|
|
|
|
if (!truncateOutput) {
|
|
kernelDesc.options = MLOperatorKernelOptions::AllowDynamicInputShapes;
|
|
Microsoft::WRL::ComPtr<MLOperatorKernelFactory> factory = wil::MakeOrThrow<MLOperatorKernelFactory>(CreateABIFooKernel<true>);
|
|
|
|
WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorKernel(&kernelDesc, factory.Get(), nullptr));
|
|
} else {
|
|
Microsoft::WRL::ComPtr<MLOperatorKernelFactory> factory = wil::MakeOrThrow<MLOperatorKernelFactory>(CreateTruncatedABIFooKernel);
|
|
WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorKernel(
|
|
&kernelDesc,
|
|
factory.Get(),
|
|
testCases[caseIndex].useShapeInferenceInKernel ? shapeInferrer.Get() : nullptr));
|
|
}
|
|
|
|
// Prepare inputs
|
|
std::vector<int64_t> dimsX = {3, 2};
|
|
std::vector<float> valuesX = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
|
|
|
// Prepare expected inputs and outputs
|
|
std::vector<int64_t> expectedDimsY = {truncateOutput ? 2 : 3, 2};
|
|
// now the expected value should be Add's result.
|
|
std::vector<float> expectedValuesY = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};
|
|
if (truncateOutput) {
|
|
// The leading dimension is truncated, and the second dimension has two elements over that dim
|
|
expectedValuesY.resize(expectedValuesY.size() - 2);
|
|
}
|
|
|
|
// Load the model and sessions
|
|
std::wstring fullPath = FileHelpers::GetModulePath() + (truncateOutput ? L"foo_truncated.onnx" : L"foo.onnx");
|
|
LearningModel model = LearningModel::LoadFromFilePath(fullPath, operatorProvider);
|
|
LearningModelSession session(model);
|
|
|
|
// Bind input and outputs
|
|
LearningModelBinding bindings(session);
|
|
|
|
TensorFloat inputTensor = TensorFloat::CreateFromArray(dimsX, winrt::array_view<const float>(std::move(valuesX)));
|
|
bindings.Bind(winrt::hstring(L"X"), inputTensor);
|
|
|
|
auto outputValue = TensorFloat::Create();
|
|
WINML_EXPECT_NO_THROW(bindings.Bind(L"Y", outputValue));
|
|
|
|
// Evaluate the model
|
|
winrt::hstring correlationId;
|
|
WINML_EXPECT_NO_THROW(session.Evaluate(bindings, correlationId));
|
|
|
|
// Verify the result shape
|
|
WINML_EXPECT_EQUAL(expectedDimsY.size(), outputValue.Shape().Size());
|
|
for (uint32_t j = 0; j < outputValue.Shape().Size(); j++) {
|
|
WINML_EXPECT_EQUAL(expectedDimsY.at(j), outputValue.Shape().GetAt(j));
|
|
}
|
|
|
|
// Verify the result values
|
|
auto buffer = outputValue.GetAsVectorView();
|
|
WINML_EXPECT_TRUE(buffer != nullptr);
|
|
WINML_EXPECT_TRUE(std::equal(expectedValuesY.cbegin(), expectedValuesY.cend(), begin(buffer)));
|
|
|
|
// Release the model before operatorProvider goes out of scope
|
|
model = nullptr;
|
|
|
|
if (shapeInferenceContext) {
|
|
// Check that the shape inference context is closed and safely fails
|
|
MLOperatorEdgeDescription edgeDesc;
|
|
WINML_EXPECT_EQUAL(E_INVALIDARG, shapeInferenceContext->GetInputEdgeDescription(0, &edgeDesc));
|
|
}
|
|
}
|
|
}
|
|
|
|
const CustomOpsTestsApi& getapi() {
|
|
static CustomOpsTestsApi api =
|
|
{
|
|
CustomOpsScenarioTestsClassSetup,
|
|
CustomOperatorFusion,
|
|
CustomKernelWithBuiltInSchema,
|
|
CustomKernelWithCustomSchema
|
|
};
|
|
|
|
if (SkipGpuTests()) {
|
|
api.CustomOperatorFusion = SkipTest;
|
|
}
|
|
if (RuntimeParameterExists(L"noVideoFrameTests")) {
|
|
api.CustomOperatorFusion = SkipTest;
|
|
}
|
|
return api;
|
|
} |