onnxruntime/winml/test/scenario/cppwinrt/CustomOps.cpp
Ori Levari f564569a80
Adapter Model and Environment tests (#3469)
*Adapter Model and Environment tests
*winml test macro clean up and extension
2020-04-14 13:36:31 -07:00

736 lines
No EOL
28 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "testPch.h"
#include <wil/result.h>
#include <D3d11_4.h>
#include <dxgi1_6.h>
#include "filehelpers.h"
#include <fstream>
#include <winrt/Windows.Graphics.Imaging.h>
#include <winrt/Windows.Media.h>
#include "winrt/Windows.Storage.h"
#include <winrt/Windows.Storage.Streams.h>
#include <MemoryBuffer.h>
#include <gsl/gsl>
#include "CustomOperatorProvider.h"
#include "CustomOps.h"
// For custom operator and shape inferencing support
#include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h"
#include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h"
#include "core/providers/dml/OperatorAuthorHelper/OperatorHelper.h"
#include "core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h"
#include "core/graph/constants.h"
#include "CustomNullOp.h"
#include <wil/wrl.h>
using namespace winrt;
using namespace winrt::Windows::AI::MachineLearning;
using namespace winrt::Windows::Foundation::Collections;
using namespace winrt::Windows::Media;
using namespace winrt::Windows::Graphics::Imaging;
using namespace winrt::Windows::Storage;
using namespace winrt::Windows::Storage::Streams;
static void CustomOpsScenarioTestsClassSetup()
{
init_apartment();
}
// Method-level setup hook for the GPU scenario tests. All of its behavior comes from the
// GPUTEST macro, which is defined outside this file.
// NOTE(review): presumably GPUTEST skips the test when no suitable GPU is available -
// confirm against the macro definition.
static void CustomOpsScenarioTestsGpuMethodSetup()
{
GPUTEST;
}
// Tests that the execution provider correctly fuses operators together when custom ops are involved.
//
// The test registers counting "null" kernels for every operator type of interest through a
// custom operator provider, evaluates squeezenet on a DirectX device, and then compares the
// per-operator counters against the expected post-fusion counts.
static void CustomOperatorFusion() {
constexpr const wchar_t* c_modelFilename = L"squeezenet_tensor_input.onnx";
// This particular model has 25 Conv ops and 25 Relu ops, all of which are eligible for fusion so we expect them
// all to be fused (removing them from the graph) and replaced with the appropriate fused op instead. The same
// goes for the single Gemm+Sigmoid in the model too.
constexpr const uint32_t c_expectedConvOps = 0;
constexpr const uint32_t c_expectedReluOps = 0;
constexpr const uint32_t c_expectedFusedConvOps = 25;
constexpr const uint32_t c_expectedGemmOps = 0;
constexpr const uint32_t c_expectedSigmoidOps = 0;
constexpr const uint32_t c_expectedFusedGemmOps = 1;
// These ops are also part of the model but shouldn't be fused
constexpr const uint32_t c_expectedBatchNormOps = 1;
constexpr const uint32_t c_expectedMaxPoolOps = 3;
constexpr const uint32_t c_expectedConcatOps = 8;
// Operator provider local to this test. It owns an operator registry populated with one
// NullOperatorFactory kernel per operator type, each wired to its own counter in CallCounts.
struct CallbackOperatorProvider :
winrt::implements<
CallbackOperatorProvider,
winrt::Windows::AI::MachineLearning::ILearningModelOperatorProvider,
ILearningModelOperatorProviderNative>
{
// One counter per registered operator type. A pointer to each counter is handed to
// NullOperatorFactory::RegisterKernel below.
// NOTE(review): presumably the null kernel bumps its counter once per kernel instance the
// runtime creates - confirm in CustomNullOp.h.
struct CallCounts
{
std::atomic<uint32_t> conv = 0;
std::atomic<uint32_t> relu = 0;
std::atomic<uint32_t> fusedConv = 0;
std::atomic<uint32_t> gemm = 0;
std::atomic<uint32_t> sigmoid = 0;
std::atomic<uint32_t> fusedGemm = 0;
std::atomic<uint32_t> batchNorm = 0;
std::atomic<uint32_t> maxPool = 0;
std::atomic<uint32_t> concat = 0;
};
// Read-only access to the counters for verification after evaluation.
const CallCounts& GetCallCounts()
{
return m_callCounts;
}
// Creates the registry and registers every kernel the test wants to count.
CallbackOperatorProvider()
{
using namespace OperatorHelper;
WINML_EXPECT_HRESULT_SUCCEEDED(MLCreateOperatorRegistry(m_registry.put()));
// REGISTER_KERNEL stringizes the operator name, looks up its sc_sinceVer_<name> constant in the
// given operator set, and registers a null kernel with a NullShapeInferrer wrapping the given
// shape helper. push_macro/pop_macro keep the definition from leaking past this constructor.
#pragma push_macro("REGISTER_KERNEL")
#define REGISTER_KERNEL(_name, _domain, _opSet, _shapeInferrer, _callCount) \
NullOperatorFactory::RegisterKernel( \
#_name, \
(_domain), \
_opSet::sc_sinceVer_ ## _name, \
m_registry, \
winrt::make<NullShapeInferrer<_shapeInferrer>>(), \
(_callCount));
REGISTER_KERNEL(Conv, onnxruntime::kOnnxDomain, OnnxOperatorSet7, ConvHelper, &m_callCounts.conv);
REGISTER_KERNEL(Relu, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.relu);
REGISTER_KERNEL(FusedConv, onnxruntime::kMSDmlDomain, MsftOperatorSet1, ConvHelper, &m_callCounts.fusedConv);
REGISTER_KERNEL(Gemm, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GemmHelper, &m_callCounts.gemm);
REGISTER_KERNEL(Sigmoid, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.sigmoid);
REGISTER_KERNEL(FusedGemm, onnxruntime::kMSDmlDomain, MsftOperatorSet1, GemmHelper, &m_callCounts.fusedGemm);
REGISTER_KERNEL(BatchNormalization, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.batchNorm);
REGISTER_KERNEL(MaxPool, onnxruntime::kOnnxDomain, OnnxOperatorSet7, PoolingHelper, &m_callCounts.maxPool);
REGISTER_KERNEL(Concat, onnxruntime::kOnnxDomain, OnnxOperatorSet7, ConcatHelper, &m_callCounts.concat);
#pragma pop_macro("REGISTER_KERNEL")
}
// ILearningModelOperatorProviderNative: copies the registry pointer into the out-parameter.
// Returns E_POINTER when the out-parameter is null, S_OK otherwise.
STDMETHOD(GetRegistry)(IMLOperatorRegistry** ppOperatorRegistry)
{
if (ppOperatorRegistry == nullptr)
{
return E_POINTER;
}
m_registry.copy_to(ppOperatorRegistry);
return S_OK;
}
private:
winrt::com_ptr<IMLOperatorRegistry> m_registry;
CallCounts m_callCounts;
};
// Build the provider and keep both the WinRT projection (for LoadFromFilePath) and a way
// back to the implementation type (for reading the counters afterwards).
auto customOperatorProvider = winrt::make<CallbackOperatorProvider>();
auto provider = customOperatorProvider.as<ILearningModelOperatorProvider>();
LearningModelDevice device = nullptr;
WINML_EXPECT_NO_THROW(device = LearningModelDevice(LearningModelDeviceKind::DirectX));
std::wstring fullPath = FileHelpers::GetModulePath() + c_modelFilename;
auto model = LearningModel::LoadFromFilePath(fullPath, provider);
auto featureValue = FileHelpers::LoadImageFeatureValue(L"227x227.png");
LearningModelSession session = nullptr;
WINML_EXPECT_NO_THROW(session = LearningModelSession(model, device));
LearningModelBinding modelBinding(session);
modelBinding.Bind(L"data", featureValue);
// The evaluation result itself is not inspected; only the kernel counters matter here.
auto result = session.Evaluate(modelBinding, L"");
const auto& callCounts = customOperatorProvider.as<CallbackOperatorProvider>()->GetCallCounts();
// Verify that the correct number of each operator was seen (i.e. that none were dropped / incorrectly fused)
WINML_EXPECT_EQUAL(c_expectedConvOps, callCounts.conv);
WINML_EXPECT_EQUAL(c_expectedReluOps, callCounts.relu);
WINML_EXPECT_EQUAL(c_expectedFusedConvOps, callCounts.fusedConv);
WINML_EXPECT_EQUAL(c_expectedGemmOps, callCounts.gemm);
WINML_EXPECT_EQUAL(c_expectedSigmoidOps, callCounts.sigmoid);
WINML_EXPECT_EQUAL(c_expectedFusedGemmOps, callCounts.fusedGemm);
WINML_EXPECT_EQUAL(c_expectedBatchNormOps, callCounts.batchNorm);
WINML_EXPECT_EQUAL(c_expectedMaxPoolOps, callCounts.maxPool);
WINML_EXPECT_EQUAL(c_expectedConcatOps, callCounts.concat);
}
// Operator provider that owns a fresh operator registry. Tests fetch the raw registry
// pointer, register their own schemas/kernels on it, and then pass the provider to
// LearningModel::LoadFromFilePath.
struct LocalCustomOperatorProvider :
    winrt::implements<
        LocalCustomOperatorProvider,
        winrt::Windows::AI::MachineLearning::ILearningModelOperatorProvider,
        ILearningModelOperatorProviderNative>
{
    // Creates the registry; a creation failure is reported as a test expectation failure.
    LocalCustomOperatorProvider()
    {
        WINML_EXPECT_HRESULT_SUCCEEDED(MLCreateOperatorRegistry(m_registry.put()));
    }

    // ILearningModelOperatorProviderNative: copies the registry pointer into the
    // out-parameter. Returns E_POINTER for a null out-parameter, S_OK otherwise.
    STDMETHOD(GetRegistry)(IMLOperatorRegistry** ppOperatorRegistry)
    {
        if (ppOperatorRegistry != nullptr)
        {
            m_registry.copy_to(ppOperatorRegistry);
            return S_OK;
        }

        return E_POINTER;
    }

    // Test-only convenience accessor: the registry as a raw pointer taken from m_registry.
    IMLOperatorRegistry* GetRegistry()
    {
        return m_registry.get();
    }

protected:
    winrt::com_ptr<IMLOperatorRegistry> m_registry;
};
// Checks that the test attributes set on ABI kernels can be queried with the correct values.
//
// attrs: any attribute source (kernel creation, type inference or shape inference context)
//        that is expected to expose the "DefaultedNonRequired*" attributes with the
//        defaults { 1, 1.0f, "1", {1, 2}, {1.0f, 2.0f}, {"1", "2"} }.
// Mismatches are reported through the WINML_EXPECT_* macros.
void VerifyTestAttributes(const MLOperatorAttributes& attrs)
{
    std::string strAttr = attrs.GetAttribute("DefaultedNonRequiredString");
    WINML_EXPECT_EQUAL(strAttr, "1");

    std::vector<std::string> strArrayAttr = attrs.GetAttributeVector("DefaultedNonRequiredStringArray");
    const std::vector<std::string> expected = { "1", "2" };

    // Check the length explicitly: a shorter result must be reported as a failure rather than
    // indexed out of bounds, and a longer result must not be silently accepted.
    WINML_EXPECT_EQUAL(expected.size(), strArrayAttr.size());
    for (size_t i = 0; i < expected.size() && i < strArrayAttr.size(); ++i)
    {
        WINML_EXPECT_EQUAL(strArrayAttr[i], expected[i]);
    }

    WINML_EXPECT_EQUAL(1, attrs.GetAttribute<int64_t>("DefaultedNonRequiredInt"));
    WINML_EXPECT_EQUAL(1.0f, attrs.GetAttribute<float>("DefaultedNonRequiredFloat"));
    WINML_EXPECT_EQUAL(std::vector<int64_t>({ 1, 2 }), attrs.GetAttributeVector<int64_t>("DefaultedNonRequiredIntArray"));
    WINML_EXPECT_EQUAL(std::vector<float>({ 1.0f, 2.0f }), attrs.GetAttributeVector<float>("DefaultedNonRequiredFloatArray"));
}
// Foo kernel which is doing Add and optionally truncates its output
template <typename T, bool VerifyAttributes = false, bool Truncate = false>
class FooKernel
{
public:
FooKernel(const MLOperatorKernelCreationContext& info)
{
if (VerifyAttributes)
{
VerifyTestAttributes(info);
}
VerifyShapeInfo(info);
}
void VerifyShapeInfo(const MLOperatorKernelCreationContext& info)
{
if (!Truncate)
{
com_ptr<IMLOperatorTensorShapeDescription> shapeInfo;
WINML_EXPECT_EQUAL(info.GetInterface()->HasTensorShapeDescription(), false);
WINML_EXPECT_HRESULT_FAILED(info.GetInterface()->GetTensorShapeDescription(shapeInfo.put()));
}
else
{
com_ptr<IMLOperatorTensorShapeDescription> shapeInfo;
WINML_EXPECT_EQUAL(info.GetInterface()->HasTensorShapeDescription(), true);
WINML_EXPECT_EQUAL(info.GetInterface()->GetTensorShapeDescription(shapeInfo.put()), S_OK);
}
}
void Compute(const MLOperatorKernelContext& context) const
{
const auto X = context.GetInputTensor(0);
const auto W = context.GetInputTensor(1);
auto xData = X.GetData<T>();
auto wData = W.GetData<T>();
auto shape = X.GetShape();
// This is used to test shape inference
if (Truncate)
{
shape[0] -= 1;
}
if (!Truncate)
{
com_ptr<IMLOperatorTensor> tensor;
WINML_EXPECT_HRESULT_FAILED(context.GetInterface()->GetOutputTensor(0, tensor.put()));
}
else
{
MLOperatorTensor tensor = context.GetOutputTensor(0);
}
auto Y = context.GetOutputTensor(0, shape);
auto yData = Y.GetData<T>();
size_t size = 1;
for (size_t i = 0; i < shape.size(); i++)
{
size *= shape[i];
}
for (size_t i = 0; i < size; i++)
{
yData[i] = xData[i] + wData[i];
}
}
};
// Kernel factory callback for the non-truncating float FooKernel.
// VerifyAttributes selects whether the created kernel validates the test attribute defaults.
// Throws if kernel creation fails.
template <bool VerifyAttributes = false>
void CALLBACK CreateABIFooKernel(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel)
{
    using KernelType = MLOperatorKernel<FooKernel<float, VerifyAttributes>>;

    HRESULT hr = KernelType::CreateInstance(*kernelInfo, opKernel);
    THROW_IF_FAILED(hr);
}
// Kernel factory callback for the float FooKernel variant that both verifies the test
// attributes and truncates its output. Throws if kernel creation fails.
void CALLBACK CreateTruncatedABIFooKernel(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel)
{
    using TruncatingKernel = MLOperatorKernel<FooKernel<float, true, true>>;

    HRESULT hr = TruncatingKernel::CreateInstance(*kernelInfo, opKernel);
    THROW_IF_FAILED(hr);
}
// Test using a foo kernel which is doing Add, but register it as "Mul".
static void CustomKernelWithBuiltInSchema()
{
// Create the registry
auto operatorProvider = winrt::make<LocalCustomOperatorProvider>();
IMLOperatorRegistry* registry = operatorProvider.as<LocalCustomOperatorProvider>()->GetRegistry();
// Register the kernel
MLOperatorEdgeDescription floatTensorType =
{
MLOperatorEdgeType::Tensor,
static_cast<uint64_t>(MLOperatorTensorDataType::Float)
};
MLOperatorEdgeTypeConstrant constraint = { "T", &floatTensorType, 1 };
MLOperatorKernelDescription kernelDesc =
{
"",
"Mul",
7,
MLOperatorExecutionType::Cpu,
&constraint,
1,
nullptr,
0,
MLOperatorKernelOptions::AllowDynamicInputShapes
};
Microsoft::WRL::ComPtr<MLOperatorKernelFactory> factory = wil::MakeOrThrow<MLOperatorKernelFactory>(CreateABIFooKernel<false>);
WINML_EXPECT_HRESULT_SUCCEEDED(registry->RegisterOperatorKernel(&kernelDesc, factory.Get(), nullptr));
// Prepare inputs
std::vector<int64_t> dimsX = { 3, 2 };
std::vector<float> valuesX = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f };
// Prepare expected inputs and outputs
std::vector<int64_t> expectedDimsY = { 3, 2 };
// The expected value should be Add's result.
std::vector<float> expectedValuesY = { 2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f };
// Create the model and sessions
std::wstring fullPath = FileHelpers::GetModulePath() + L"mul.onnx";
LearningModel model = LearningModel::LoadFromFilePath(fullPath, operatorProvider);
LearningModelSession session(model);
LearningModelBinding bindings(session);
// Bind inputs and outputs
TensorFloat inputTensor = TensorFloat::CreateFromArray(dimsX, winrt::array_view<const float>(std::move(valuesX)));
bindings.Bind(winrt::hstring(L"X"), inputTensor);
auto outputValue = TensorFloat::Create();
WINML_EXPECT_NO_THROW(bindings.Bind(L"Y", outputValue));
// Evaluate the model
hstring correlationId;
WINML_EXPECT_NO_THROW(session.Evaluate(bindings, correlationId));
// Check the result shape
WINML_EXPECT_EQUAL(expectedDimsY.size(), outputValue.Shape().Size());
for (uint32_t j = 0; j < outputValue.Shape().Size(); j++)
{
WINML_EXPECT_EQUAL(expectedDimsY.at(j), outputValue.Shape().GetAt(j));
}
// Check the results
auto buffer = outputValue.GetAsVectorView();
WINML_EXPECT_TRUE(buffer != nullptr);
WINML_EXPECT_TRUE(std::equal(expectedValuesY.cbegin(), expectedValuesY.cend(), begin(buffer)));
// Release the model before operatorProvider goes out of scope
model = nullptr;
}
// Similar to MLOperatorShapeInferrer, but using an std::function
//
// Wraps an arbitrary callable as an IMLOperatorShapeInferrer so tests can supply shape
// inference logic as a (capturing) lambda.
class MLOperatorShapeInferrerFromFunc : public Microsoft::WRL::RuntimeClass<
    Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IMLOperatorShapeInferrer>
{
public:
    // shapeInferenceFn: callable invoked with the ABI context on every InferOutputShapes call.
    // It is taken by value and moved into the member, so the callable's state is not copied
    // a second time.
    explicit MLOperatorShapeInferrerFromFunc(std::function<void(IMLOperatorShapeInferenceContext*)> shapeInferenceFn) :
        m_func(std::move(shapeInferenceFn))
    {}

    // Forwards to the stored callable. Returns S_OK when it completes; any exception it throws
    // is handled by CATCH_RETURN() so that it does not escape this noexcept method.
    HRESULT STDMETHODCALLTYPE InferOutputShapes(IMLOperatorShapeInferenceContext* context) noexcept override try
    {
        m_func(context);
        return S_OK;
    }
    CATCH_RETURN();

private:
    std::function<void(IMLOperatorShapeInferenceContext*)> m_func;
};
// Test using a custom kernel and schema, while verifying attribute defaults, type mapping, and inference methods
//
// For each test case a fresh operator registry is populated with a "Foo" schema (plus two
// decoy schemas that must not be picked up) and a Foo kernel. foo.onnx / foo_truncated.onnx is
// then evaluated and the output shape and values are compared against an element-wise Add.
static void CustomKernelWithCustomSchema()
{
    // Test cases
    struct
    {
        // Whether the Foo kernel should truncate its output
        bool truncateOutput;
        // Whether a type label is used in the schema, versus a type description
        bool useTypeLabel;
        // Whether the schema provides a type inference function, and uses an output type
        // of Int32 instead of Float32
        bool useTypeInference;
        // Whether a shape inference method is provided in the schema
        bool useShapeInferenceInSchema;
        // Whether a shape inference method is provided in the kernel
        bool useShapeInferenceInKernel;
        // Whether attribute defaults are provided in the schema, instead of the kernel
        bool attributeDefaultsInSchema;
    } testCases[] =
    {
        {false, true, false, false, false, false},
        {false, false, false, false, false, false},
        {false, true, true, false, false, true},
        {true, false, false, false, true, false},
        {true, true, true, true, true, true},
    };

    for (const auto& testCase : testCases)
    {
        // Create the registry
        auto operatorProvider = winrt::make<LocalCustomOperatorProvider>();
        IMLOperatorRegistry* registry = operatorProvider.as<LocalCustomOperatorProvider>()->GetRegistry();

        // Create input and output parameters
        MLOperatorSchemaEdgeDescription inputParam = {};
        inputParam.options = MLOperatorParameterOptions::Single;
        if (!testCase.useTypeLabel)
        {
            // Type inference is only exercised together with type labels.
            assert(!testCase.useTypeInference);
            MLOperatorEdgeDescription edgeDesc = {};
            edgeDesc.edgeType = MLOperatorEdgeType::Tensor;
            edgeDesc.tensorDataType = MLOperatorTensorDataType::Float;
            inputParam.typeFormat = MLOperatorSchemaEdgeTypeFormat::EdgeDescription;
            inputParam.edgeDescription = edgeDesc;
        }
        else
        {
            inputParam.typeFormat = MLOperatorSchemaEdgeTypeFormat::Label;
            inputParam.typeLabel = "T1";
        }

        MLOperatorSchemaEdgeDescription outputParam = inputParam;
        // Type inference should set this to tensor(float) even though T2 is not matched
        // on an input label
        if (testCase.useTypeInference)
        {
            if (inputParam.typeFormat == MLOperatorSchemaEdgeTypeFormat::Label)
            {
                outputParam.typeLabel = "T2";
            }
            else
            {
                outputParam.edgeDescription.tensorDataType = MLOperatorTensorDataType::Int32;
            }
        }

        MLOperatorSchemaEdgeDescription inputs[] = { inputParam, inputParam };
        MLOperatorEdgeDescription edgeTypes[6] =
        {
            {MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::UInt32)},
            {MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::UInt64)},
            {MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Int32)},
            {MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Int64)},
            {MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Float)},
            {MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Double)}
        };

        // Type constraints. Only the first is used unless type inference is provided and
        // the kernel emits a different output type as "T2"
        MLOperatorEdgeTypeConstrant constraints[] =
        {
            {"T1", edgeTypes, static_cast<uint32_t>(std::size(edgeTypes))},
            {"T2", edgeTypes, static_cast<uint32_t>(std::size(edgeTypes))}
        };

        // Test attributes
        MLOperatorAttribute attributes[] =
        {
            {"DefaultedNonRequiredInt", MLOperatorAttributeType::Int, false},
            {"DefaultedNonRequiredFloat", MLOperatorAttributeType::Float, false},
            {"DefaultedNonRequiredString", MLOperatorAttributeType::String, false},
            {"DefaultedNonRequiredIntArray", MLOperatorAttributeType::IntArray, false},
            {"DefaultedNonRequiredFloatArray", MLOperatorAttributeType::FloatArray, false},
            {"DefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, false},
            {"NonDefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, false},
        };

        // Defaults. These are queried back during kernel creation, type and shape inference
        // and tested against the same values
        MLOperatorAttributeNameValue defaultAttributes[] =
        {
            {"DefaultedNonRequiredInt", MLOperatorAttributeType::Int, 1},
            {"DefaultedNonRequiredFloat", MLOperatorAttributeType::Float, 1},
            {"DefaultedNonRequiredString", MLOperatorAttributeType::String, 1},
            {"DefaultedNonRequiredIntArray", MLOperatorAttributeType::IntArray, 2},
            {"DefaultedNonRequiredFloatArray", MLOperatorAttributeType::FloatArray, 2},
            {"DefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, 2},
        };
        int64_t defaultInts[] = { 1, 2 };
        float defaultFloats[] = { 1.0f, 2.0f };
        const char* defaultStrings[] = { "1", "2" };
        // Scalar defaults reuse the arrays and read only their first element (count of 1 above).
        defaultAttributes[0].ints = defaultInts;
        defaultAttributes[1].floats = defaultFloats;
        defaultAttributes[2].strings = defaultStrings;
        defaultAttributes[3].ints = defaultInts;
        defaultAttributes[4].floats = defaultFloats;
        defaultAttributes[5].strings = defaultStrings;

        // Schema definition
        MLOperatorSchemaDescription schemaDesc = {};
        schemaDesc.name = "Foo";
        schemaDesc.operatorSetVersionAtLastChange = 7;
        schemaDesc.inputs = inputs;
        schemaDesc.inputCount = 2;
        schemaDesc.outputs = &outputParam;
        schemaDesc.outputCount = 1;
        schemaDesc.typeConstraints = constraints;
        schemaDesc.typeConstraintCount = testCase.useTypeLabel ? 2 : 0;
        schemaDesc.attributes = attributes;
        schemaDesc.attributeCount = static_cast<uint32_t>(std::size(attributes));
        if (testCase.attributeDefaultsInSchema)
        {
            schemaDesc.defaultAttributes = defaultAttributes;
            schemaDesc.defaultAttributeCount = static_cast<uint32_t>(std::size(defaultAttributes));
        }

        Microsoft::WRL::ComPtr<MLOperatorTypeInferrer> typeInferrer;
        Microsoft::WRL::ComPtr<MLOperatorShapeInferrerFromFunc> shapeInferrer;

        // Type inference function
        if (testCase.useTypeInference)
        {
            typeInferrer = wil::MakeOrThrow<MLOperatorTypeInferrer>([](IMLOperatorTypeInferenceContext* ctx) -> void
            {
                VerifyTestAttributes(MLOperatorTypeInferenceContext(ctx));
                MLOperatorEdgeDescription edgeDesc = {};
                edgeDesc.edgeType = MLOperatorEdgeType::Tensor;
                edgeDesc.tensorDataType = MLOperatorTensorDataType::Float;
                MLOperatorTypeInferenceContext(ctx).SetOutputEdgeDescription(0, &edgeDesc);
            });
        }

        // Store the shape inference context with a reference following the call to InferOutputShapes.
        // This will be called after loading the model as an isolated test for how ABI context objects
        // are "closed."
        Microsoft::WRL::ComPtr<IMLOperatorShapeInferenceContext> shapeInferenceContext;

        // Shape inference is tested by truncating the output size
        bool truncateOutput = testCase.truncateOutput;
        if (truncateOutput)
        {
            shapeInferrer = wil::MakeOrThrow<MLOperatorShapeInferrerFromFunc>([&shapeInferenceContext](IMLOperatorShapeInferenceContext* ctx) -> void
            {
                VerifyTestAttributes(MLShapeInferenceContext(ctx));
                MLShapeInferenceContext(ctx).SetOutputTensorShape(0, { 2, 2 });
                shapeInferenceContext = ctx;
            });
        }

        // Register the schema
        MLOperatorSetId opsetId = { "", 7 };
        MLOperatorSchemaDescription* opSchemaDescs = &schemaDesc;
        WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorSetSchema(
            &opsetId,
            1,
            &opSchemaDescs,
            1,
            typeInferrer.Get(),
            testCase.useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr
        ));

        {
            // Register a future version of the schema in the same domain, while setting its
            // input count to zero to ensure it is not being used.
            auto futureSchemaDesc = schemaDesc;
            futureSchemaDesc.inputCount = 0;
            MLOperatorSetId id = { "", 9 };
            MLOperatorSchemaDescription* schemaDescs = &futureSchemaDesc;
            WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorSetSchema(
                &id,
                7,
                &schemaDescs,
                1,
                typeInferrer.Get(),
                testCase.useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr
            ));
        }

        {
            // Register in another (unused) domain to the custom registry
            auto otherSchemaDesc = schemaDesc;
            otherSchemaDesc.inputCount = 0;
            MLOperatorSetId id = { "otherDomain", 7 };
            MLOperatorSchemaDescription* schemaDescs = &otherSchemaDesc;
            WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorSetSchema(
                &id,
                1,
                &schemaDescs,
                1,
                typeInferrer.Get(),
                testCase.useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr
            ));
        }

        // Register the Foo kernel
        MLOperatorEdgeDescription floatTensorEdgeDesc = {};
        floatTensorEdgeDesc.edgeType = MLOperatorEdgeType::Tensor;
        floatTensorEdgeDesc.tensorDataType = MLOperatorTensorDataType::Float;
        MLOperatorEdgeTypeConstrant kernelConstraint = { "T", &floatTensorEdgeDesc, 1 };
        MLOperatorKernelDescription kernelDesc =
        {
            "",
            "Foo",
            7,
            MLOperatorExecutionType::Cpu,
            &kernelConstraint,
            1
        };
        if (!testCase.attributeDefaultsInSchema)
        {
            kernelDesc.defaultAttributes = defaultAttributes;
            kernelDesc.defaultAttributeCount = static_cast<uint32_t>(std::size(defaultAttributes));
        }

        if (!truncateOutput)
        {
            kernelDesc.options = MLOperatorKernelOptions::AllowDynamicInputShapes;
            Microsoft::WRL::ComPtr<MLOperatorKernelFactory> factory = wil::MakeOrThrow<MLOperatorKernelFactory>(CreateABIFooKernel<true>);
            WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorKernel(&kernelDesc, factory.Get(), nullptr));
        }
        else
        {
            Microsoft::WRL::ComPtr<MLOperatorKernelFactory> factory = wil::MakeOrThrow<MLOperatorKernelFactory>(CreateTruncatedABIFooKernel);
            WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorKernel(
                &kernelDesc,
                factory.Get(),
                testCase.useShapeInferenceInKernel ? shapeInferrer.Get() : nullptr
            ));
        }

        // Prepare inputs
        std::vector<int64_t> dimsX = { 3, 2 };
        std::vector<float> valuesX = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f };

        // Prepare expected inputs and outputs
        std::vector<int64_t> expectedDimsY = { truncateOutput ? 2 : 3, 2 };
        // now the expected value should be Add's result.
        std::vector<float> expectedValuesY = { 2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f };
        if (truncateOutput)
        {
            // The leading dimension is truncated, and the second dimension has two elements over that dim
            expectedValuesY.resize(expectedValuesY.size() - 2);
        }

        // Load the model and sessions
        std::wstring fullPath = FileHelpers::GetModulePath() + (truncateOutput ? L"foo_truncated.onnx" : L"foo.onnx");
        LearningModel model = LearningModel::LoadFromFilePath(fullPath, operatorProvider);
        LearningModelSession session(model);

        // Bind input and outputs
        LearningModelBinding bindings(session);
        TensorFloat inputTensor = TensorFloat::CreateFromArray(dimsX, winrt::array_view<const float>(std::move(valuesX)));
        bindings.Bind(winrt::hstring(L"X"), inputTensor);
        auto outputValue = TensorFloat::Create();
        WINML_EXPECT_NO_THROW(bindings.Bind(L"Y", outputValue));

        // Evaluate the model
        hstring correlationId;
        WINML_EXPECT_NO_THROW(session.Evaluate(bindings, correlationId));

        // Verify the result shape. The loop is bounded by both ranks so that a rank mismatch is
        // reported by the expectation above instead of throwing from an out-of-range access.
        const auto outputShape = outputValue.Shape();
        WINML_EXPECT_EQUAL(expectedDimsY.size(), outputShape.Size());
        for (uint32_t j = 0; j < outputShape.Size() && j < expectedDimsY.size(); j++)
        {
            WINML_EXPECT_EQUAL(expectedDimsY[j], outputShape.GetAt(j));
        }

        // Verify the result values. The four-iterator std::equal also compares the lengths, so an
        // output that is too short fails the expectation instead of being read past its end.
        auto buffer = outputValue.GetAsVectorView();
        WINML_EXPECT_TRUE(buffer != nullptr);
        if (buffer != nullptr)
        {
            WINML_EXPECT_TRUE(std::equal(expectedValuesY.cbegin(), expectedValuesY.cend(), begin(buffer), end(buffer)));
        }

        // Release the model before operatorProvider goes out of scope
        model = nullptr;

        if (shapeInferenceContext)
        {
            // Check that the shape inference context is closed and safely fails
            MLOperatorEdgeDescription edgeDesc = {};
            WINML_EXPECT_EQUAL(E_INVALIDARG, shapeInferenceContext->GetInputEdgeDescription(0, &edgeDesc));
        }
    }
}
// Returns the table of setup hooks and test entry points defined in this file.
// The table is a function-local static constant, so every call returns the same object.
const CustomOpsTestsApi& getapi() {
    static constexpr CustomOpsTestsApi s_customOpsTestsApi =
    {
        // Setup hooks
        CustomOpsScenarioTestsClassSetup,
        CustomOpsScenarioTestsGpuMethodSetup,
        // Test bodies
        CustomOperatorFusion,
        CustomKernelWithBuiltInSchema,
        CustomKernelWithCustomSchema
    };

    return s_customOpsTestsApi;
}