mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Extend opset 13 support for: - Split-13 - Squeeze-13 - Unsqueeze-13 - Reshape-13 - QuantizeLinear-13 - DequantizeLinear-13 - ReduceSum-13 - Resize-13 Also: - Rename the file where all the opset versions are stored from "OperatorRegistration.h" to "OperatorVersions.h", which will make it much less confusing in the future when looking given there's another file called "OperatorRegistration.h" that corresponds to "OperatorRegistration.cpp". - Detemplatize many of the OperatorHelper.h constructors, which duplicate multiple instantiations due to the operator helper classes not sharing a common base class, by wrapping them with an adapter. Ideally there would be a common COM base interface that both IMLOperatorKernelCreationContext and IMLOperatorShapeInferenceContext implementation objects would implement, which a wrapper in MLOperatorAuthorHelper.h could QI for. - Fix style formatting issues in OperatorHelper.h (sorry for the noise). ``` Summary: Total=4679, Passed=4355, Failed=0, Blocked=0, Not Run=0, Skipped=324 ``` Corresponding WindowsAI PR: https://microsoft.visualstudio.com/WindowsAI/_git/WindowsAI/pullrequest/6973645 Related work items: #36672908, #36672926
694 lines
No EOL
28 KiB
C++
694 lines
No EOL
28 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "testPch.h"
|
|
#include <wil/result.h>
|
|
#include <D3d11_4.h>
|
|
#include <dxgi1_6.h>
|
|
#include "filehelpers.h"
|
|
#include <fstream>
|
|
#include <MemoryBuffer.h>
|
|
#include <gsl/gsl>
|
|
#include "CustomOperatorProvider.h"
|
|
#include "CustomOps.h"
|
|
|
|
// For custom operator and shape inferencing support
|
|
#include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h"
|
|
#include "core/providers/dml/DmlExecutionProvider/src/ErrorHandling.h"
|
|
#include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h"
|
|
#include "core/providers/dml/OperatorAuthorHelper/OperatorHelper.h"
|
|
#include "core/providers/dml/OperatorAuthorHelper/OperatorVersions.h"
|
|
#include "core/graph/constants.h"
|
|
#include "CustomNullOp.h"
|
|
#include <wil/wrl.h>
|
|
|
|
using namespace winml;
|
|
using namespace wfc;
|
|
using namespace wm;
|
|
using namespace wgi;
|
|
using namespace ws;
|
|
using namespace wss;
|
|
|
|
static void CustomOpsScenarioTestsClassSetup()
|
|
{
|
|
winrt::init_apartment();
|
|
#ifdef BUILD_INBOX
|
|
winrt_activation_handler = WINRT_RoGetActivationFactory;
|
|
#endif
|
|
}
|
|
|
|
// Tests that the execution provider correctly fuses operators together when custom ops are involved.
|
|
static void CustomOperatorFusion() {
|
|
constexpr const wchar_t* c_modelFilename = L"squeezenet_tensor_input.onnx";
|
|
|
|
// This particular model has 25 Conv ops and 25 Relu ops, all of which are eligible for fusion so we expect them
|
|
// all to be fused (removing them from the graph) and replaced with the appropriate fused op instead. The same
|
|
// goes for the single Gemm+Sigmoid in the model too.
|
|
constexpr const uint32_t c_expectedConvOps = 0;
|
|
constexpr const uint32_t c_expectedReluOps = 0;
|
|
constexpr const uint32_t c_expectedFusedConvOps = 25;
|
|
constexpr const uint32_t c_expectedGemmOps = 0;
|
|
constexpr const uint32_t c_expectedSigmoidOps = 0;
|
|
constexpr const uint32_t c_expectedFusedGemmOps = 1;
|
|
|
|
// These ops are also part of the model but shouldn't be fused
|
|
constexpr const uint32_t c_expectedBatchNormOps = 1;
|
|
constexpr const uint32_t c_expectedMaxPoolOps = 3;
|
|
constexpr const uint32_t c_expectedConcatOps = 8;
|
|
|
|
struct CallbackOperatorProvider : winrt::implements<
|
|
CallbackOperatorProvider,
|
|
winml::ILearningModelOperatorProvider,
|
|
ILearningModelOperatorProviderNative> {
|
|
struct CallCounts {
|
|
std::atomic<uint32_t> conv = 0;
|
|
std::atomic<uint32_t> relu = 0;
|
|
std::atomic<uint32_t> fusedConv = 0;
|
|
std::atomic<uint32_t> gemm = 0;
|
|
std::atomic<uint32_t> sigmoid = 0;
|
|
std::atomic<uint32_t> fusedGemm = 0;
|
|
std::atomic<uint32_t> batchNorm = 0;
|
|
std::atomic<uint32_t> maxPool = 0;
|
|
std::atomic<uint32_t> concat = 0;
|
|
};
|
|
|
|
const CallCounts& GetCallCounts() {
|
|
return m_callCounts;
|
|
}
|
|
|
|
CallbackOperatorProvider() {
|
|
using namespace OperatorHelper;
|
|
|
|
std::wostringstream dll;
|
|
dll << BINARY_NAME;
|
|
auto winml_dll_name = dll.str();
|
|
|
|
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
|
|
auto m_library = LoadLibraryExW(winml_dll_name.c_str(), nullptr, 0);
|
|
#elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PC_APP)
|
|
auto m_library = LoadPackagedLibrary(winml_dll_name.c_str(), 0 /*Reserved*/);
|
|
#endif
|
|
using create_registry_delegate = HRESULT WINAPI(_COM_Outptr_ IMLOperatorRegistry * *registry);
|
|
auto create_registry = reinterpret_cast<create_registry_delegate*>(GetProcAddress(m_library, "MLCreateOperatorRegistry"));
|
|
WINML_EXPECT_HRESULT_SUCCEEDED(create_registry(m_registry.put()));
|
|
|
|
#pragma push_macro("REGISTER_KERNEL")
|
|
#define REGISTER_KERNEL(_name, _domain, _opSet, _shapeInferrer, _callCount) \
|
|
NullOperatorFactory::RegisterKernel( \
|
|
#_name, \
|
|
(_domain), \
|
|
_opSet::sc_sinceVer_##_name, \
|
|
m_registry, \
|
|
winrt::make<NullShapeInferrer<_shapeInferrer>>(), \
|
|
(_callCount));
|
|
|
|
REGISTER_KERNEL(Conv, onnxruntime::kOnnxDomain, OnnxOperatorSet7, ConvHelper, &m_callCounts.conv);
|
|
REGISTER_KERNEL(Relu, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.relu);
|
|
REGISTER_KERNEL(FusedConv, onnxruntime::kMSDmlDomain, MsftOperatorSet1, ConvHelper, &m_callCounts.fusedConv);
|
|
|
|
REGISTER_KERNEL(Gemm, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GemmHelper, &m_callCounts.gemm);
|
|
REGISTER_KERNEL(Sigmoid, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.sigmoid);
|
|
REGISTER_KERNEL(FusedGemm, onnxruntime::kMSDmlDomain, MsftOperatorSet1, GemmHelper, &m_callCounts.fusedGemm);
|
|
|
|
REGISTER_KERNEL(BatchNormalization, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.batchNorm);
|
|
REGISTER_KERNEL(MaxPool, onnxruntime::kOnnxDomain, OnnxOperatorSet7, PoolingHelper, &m_callCounts.maxPool);
|
|
REGISTER_KERNEL(Concat, onnxruntime::kOnnxDomain, OnnxOperatorSet7, ConcatHelper, &m_callCounts.concat);
|
|
|
|
#pragma pop_macro("REGISTER_KERNEL")
|
|
}
|
|
|
|
STDMETHOD(GetRegistry)
|
|
(IMLOperatorRegistry** ppOperatorRegistry) {
|
|
if (ppOperatorRegistry == nullptr) {
|
|
return E_POINTER;
|
|
}
|
|
|
|
m_registry.copy_to(ppOperatorRegistry);
|
|
return S_OK;
|
|
}
|
|
|
|
private:
|
|
winrt::com_ptr<IMLOperatorRegistry> m_registry;
|
|
CallCounts m_callCounts;
|
|
};
|
|
|
|
auto customOperatorProvider = winrt::make<CallbackOperatorProvider>();
|
|
auto provider = customOperatorProvider.as<ILearningModelOperatorProvider>();
|
|
|
|
LearningModelDevice device = nullptr;
|
|
WINML_EXPECT_NO_THROW(device = LearningModelDevice(LearningModelDeviceKind::DirectX));
|
|
std::wstring fullPath = FileHelpers::GetModulePath() + c_modelFilename;
|
|
auto model = LearningModel::LoadFromFilePath(fullPath, provider);
|
|
|
|
auto featureValue = FileHelpers::LoadImageFeatureValue(L"227x227.png");
|
|
|
|
LearningModelSession session = nullptr;
|
|
WINML_EXPECT_NO_THROW(session = LearningModelSession(model, device));
|
|
LearningModelBinding modelBinding(session);
|
|
|
|
modelBinding.Bind(L"data", featureValue);
|
|
auto result = session.Evaluate(modelBinding, L"");
|
|
|
|
const auto& callCounts = customOperatorProvider.as<CallbackOperatorProvider>()->GetCallCounts();
|
|
|
|
// Verify that the correct number of each operator was seen (i.e. that none were dropped / incorrectly fused)
|
|
WINML_EXPECT_EQUAL(c_expectedConvOps, callCounts.conv);
|
|
WINML_EXPECT_EQUAL(c_expectedReluOps, callCounts.relu);
|
|
WINML_EXPECT_EQUAL(c_expectedFusedConvOps, callCounts.fusedConv);
|
|
WINML_EXPECT_EQUAL(c_expectedGemmOps, callCounts.gemm);
|
|
WINML_EXPECT_EQUAL(c_expectedSigmoidOps, callCounts.sigmoid);
|
|
WINML_EXPECT_EQUAL(c_expectedFusedGemmOps, callCounts.fusedGemm);
|
|
WINML_EXPECT_EQUAL(c_expectedBatchNormOps, callCounts.batchNorm);
|
|
WINML_EXPECT_EQUAL(c_expectedMaxPoolOps, callCounts.maxPool);
|
|
WINML_EXPECT_EQUAL(c_expectedConcatOps, callCounts.concat);
|
|
}
|
|
|
|
struct LocalCustomOperatorProvider : winrt::implements<
|
|
LocalCustomOperatorProvider,
|
|
winml::ILearningModelOperatorProvider,
|
|
ILearningModelOperatorProviderNative> {
|
|
LocalCustomOperatorProvider() {
|
|
|
|
std::wostringstream dll;
|
|
dll << BINARY_NAME;
|
|
auto winml_dll_name = dll.str();
|
|
|
|
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
|
|
auto m_library = LoadLibraryExW(winml_dll_name.c_str(), nullptr, 0);
|
|
#elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PC_APP)
|
|
auto m_library = LoadPackagedLibrary(winml_dll_name.c_str(), 0 /*Reserved*/);
|
|
#endif
|
|
using create_registry_delegate = HRESULT WINAPI(_COM_Outptr_ IMLOperatorRegistry * *registry);
|
|
auto create_registry = reinterpret_cast<create_registry_delegate*>(GetProcAddress(m_library, "MLCreateOperatorRegistry"));
|
|
WINML_EXPECT_HRESULT_SUCCEEDED(create_registry(m_registry.put()));
|
|
}
|
|
|
|
STDMETHOD(GetRegistry)
|
|
(IMLOperatorRegistry** ppOperatorRegistry) {
|
|
if (ppOperatorRegistry == nullptr) {
|
|
return E_POINTER;
|
|
}
|
|
|
|
m_registry.copy_to(ppOperatorRegistry);
|
|
return S_OK;
|
|
}
|
|
|
|
IMLOperatorRegistry* GetRegistry() {
|
|
return m_registry.get();
|
|
}
|
|
|
|
protected:
|
|
winrt::com_ptr<IMLOperatorRegistry> m_registry;
|
|
};
|
|
|
|
// Checks test attributes set on ABI kernels can be queried with correct values
|
|
void VerifyTestAttributes(const MLOperatorAttributes& attrs) {
|
|
std::string strAttr = attrs.GetAttribute("DefaultedNonRequiredString");
|
|
WINML_EXPECT_EQUAL(strAttr, "1");
|
|
|
|
std::vector<std::string> strArrayAttr = attrs.GetAttributeVector("DefaultedNonRequiredStringArray");
|
|
std::vector<std::string> expected = std::vector<std::string>({"1", "2"});
|
|
for (size_t i = 0; i < expected.size(); ++i) {
|
|
WINML_EXPECT_EQUAL(strArrayAttr[i], expected[i]);
|
|
}
|
|
|
|
WINML_EXPECT_EQUAL(1, attrs.GetAttribute<int64_t>("DefaultedNonRequiredInt"));
|
|
WINML_EXPECT_EQUAL(1.0f, attrs.GetAttribute<float>("DefaultedNonRequiredFloat"));
|
|
|
|
WINML_EXPECT_EQUAL(std::vector<int64_t>({1, 2}), attrs.GetAttributeVector<int64_t>("DefaultedNonRequiredIntArray"));
|
|
WINML_EXPECT_EQUAL(std::vector<float>({1.0f, 2.0f}), attrs.GetAttributeVector<float>("DefaultedNonRequiredFloatArray"));
|
|
}
|
|
|
|
// Foo kernel which is doing Add and optionally truncates its output
|
|
template <typename T, bool VerifyAttributes = false, bool Truncate = false>
|
|
class FooKernel {
|
|
public:
|
|
FooKernel(const MLOperatorKernelCreationContext& info) {
|
|
if (VerifyAttributes) {
|
|
VerifyTestAttributes(info);
|
|
}
|
|
|
|
VerifyShapeInfo(info);
|
|
}
|
|
|
|
void VerifyShapeInfo(const MLOperatorKernelCreationContext& info) {
|
|
if (!Truncate) {
|
|
winrt::com_ptr<IMLOperatorTensorShapeDescription> shapeInfo;
|
|
WINML_EXPECT_EQUAL(info.GetInterface()->HasTensorShapeDescription(), false);
|
|
WINML_EXPECT_HRESULT_FAILED(info.GetInterface()->GetTensorShapeDescription(shapeInfo.put()));
|
|
} else {
|
|
winrt::com_ptr<IMLOperatorTensorShapeDescription> shapeInfo;
|
|
WINML_EXPECT_EQUAL(info.GetInterface()->HasTensorShapeDescription(), true);
|
|
WINML_EXPECT_EQUAL(info.GetInterface()->GetTensorShapeDescription(shapeInfo.put()), S_OK);
|
|
}
|
|
}
|
|
|
|
void Compute(const MLOperatorKernelContext& context) const {
|
|
const auto X = context.GetInputTensor(0);
|
|
const auto W = context.GetInputTensor(1);
|
|
|
|
auto xData = X.GetData<T>();
|
|
auto wData = W.GetData<T>();
|
|
|
|
auto shape = X.GetShape();
|
|
|
|
// This is used to test shape inference
|
|
if (Truncate) {
|
|
shape[0] -= 1;
|
|
}
|
|
|
|
if (!Truncate) {
|
|
winrt::com_ptr<IMLOperatorTensor> tensor;
|
|
WINML_EXPECT_HRESULT_FAILED(context.GetInterface()->GetOutputTensor(0, tensor.put()));
|
|
} else {
|
|
MLOperatorTensor tensor = context.GetOutputTensor(0);
|
|
}
|
|
|
|
auto Y = context.GetOutputTensor(0, shape);
|
|
auto yData = Y.GetData<T>();
|
|
|
|
size_t size = 1;
|
|
for (size_t i = 0; i < shape.size(); i++) {
|
|
size *= shape[i];
|
|
}
|
|
|
|
for (size_t i = 0; i < size; i++) {
|
|
yData[i] = xData[i] + wData[i];
|
|
}
|
|
}
|
|
};
|
|
|
|
template <bool VerifyTestAttributes = false>
|
|
void CALLBACK CreateABIFooKernel(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel) {
|
|
HRESULT hr = MLOperatorKernel<FooKernel<float, VerifyTestAttributes>>::CreateInstance(*kernelInfo, opKernel);
|
|
THROW_IF_FAILED(hr);
|
|
}
|
|
|
|
void CALLBACK CreateTruncatedABIFooKernel(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel) {
|
|
HRESULT hr = MLOperatorKernel<FooKernel<float, true, true>>::CreateInstance(*kernelInfo, opKernel);
|
|
THROW_IF_FAILED(hr);
|
|
}
|
|
|
|
// Test using a foo kernel which is doing Add, but register it as "Mul".
|
|
static void CustomKernelWithBuiltInSchema() {
|
|
// Create the registry
|
|
auto operatorProvider = winrt::make<LocalCustomOperatorProvider>();
|
|
IMLOperatorRegistry* registry = operatorProvider.as<LocalCustomOperatorProvider>()->GetRegistry();
|
|
|
|
// Register the kernel
|
|
MLOperatorEdgeDescription floatTensorType =
|
|
{
|
|
MLOperatorEdgeType::Tensor,
|
|
static_cast<uint64_t>(MLOperatorTensorDataType::Float)};
|
|
|
|
MLOperatorEdgeTypeConstrant constraint = {"T", &floatTensorType, 1};
|
|
|
|
MLOperatorKernelDescription kernelDesc =
|
|
{
|
|
"",
|
|
"Mul",
|
|
7,
|
|
MLOperatorExecutionType::Cpu,
|
|
&constraint,
|
|
1,
|
|
nullptr,
|
|
0,
|
|
MLOperatorKernelOptions::AllowDynamicInputShapes};
|
|
|
|
Microsoft::WRL::ComPtr<MLOperatorKernelFactory> factory = wil::MakeOrThrow<MLOperatorKernelFactory>(CreateABIFooKernel<false>);
|
|
WINML_EXPECT_HRESULT_SUCCEEDED(registry->RegisterOperatorKernel(&kernelDesc, factory.Get(), nullptr));
|
|
|
|
// Prepare inputs
|
|
std::vector<int64_t> dimsX = {3, 2};
|
|
std::vector<float> valuesX = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
|
|
|
// Prepare expected inputs and outputs
|
|
std::vector<int64_t> expectedDimsY = {3, 2};
|
|
|
|
// The expected value should be Add's result.
|
|
std::vector<float> expectedValuesY = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};
|
|
|
|
// Create the model and sessions
|
|
std::wstring fullPath = FileHelpers::GetModulePath() + L"mul.onnx";
|
|
LearningModel model = LearningModel::LoadFromFilePath(fullPath, operatorProvider);
|
|
|
|
LearningModelSession session(model);
|
|
LearningModelBinding bindings(session);
|
|
|
|
// Bind inputs and outputs
|
|
TensorFloat inputTensor = TensorFloat::CreateFromArray(dimsX, winrt::array_view<const float>(std::move(valuesX)));
|
|
bindings.Bind(winrt::hstring(L"X"), inputTensor);
|
|
|
|
auto outputValue = TensorFloat::Create();
|
|
WINML_EXPECT_NO_THROW(bindings.Bind(L"Y", outputValue));
|
|
|
|
// Evaluate the model
|
|
winrt::hstring correlationId;
|
|
WINML_EXPECT_NO_THROW(session.Evaluate(bindings, correlationId));
|
|
|
|
// Check the result shape
|
|
WINML_EXPECT_EQUAL(expectedDimsY.size(), outputValue.Shape().Size());
|
|
for (uint32_t j = 0; j < outputValue.Shape().Size(); j++) {
|
|
WINML_EXPECT_EQUAL(expectedDimsY.at(j), outputValue.Shape().GetAt(j));
|
|
}
|
|
|
|
// Check the results
|
|
auto buffer = outputValue.GetAsVectorView();
|
|
WINML_EXPECT_TRUE(buffer != nullptr);
|
|
WINML_EXPECT_TRUE(std::equal(expectedValuesY.cbegin(), expectedValuesY.cend(), begin(buffer)));
|
|
|
|
// Release the model before operatorProvider goes out of scope
|
|
model = nullptr;
|
|
}
|
|
|
|
// Similar to MLOperatorShapeInferrer, but using an std::function
|
|
class MLOperatorShapeInferrerFromFunc : public Microsoft::WRL::RuntimeClass<
|
|
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IMLOperatorShapeInferrer> {
|
|
public:
|
|
MLOperatorShapeInferrerFromFunc(std::function<void(IMLOperatorShapeInferenceContext*)> shapeInferenceFn) : m_func(shapeInferenceFn) {}
|
|
|
|
HRESULT STDMETHODCALLTYPE InferOutputShapes(IMLOperatorShapeInferenceContext* context) noexcept override try {
|
|
m_func(context);
|
|
return S_OK;
|
|
}
|
|
CATCH_RETURN();
|
|
|
|
private:
|
|
std::function<void(IMLOperatorShapeInferenceContext*)> m_func;
|
|
};
|
|
|
|
// Test using a custom kernel and schema, while verifying attribute defaults, type mapping, and inference methods
|
|
static void CustomKernelWithCustomSchema() {
|
|
// Test cases
|
|
struct
|
|
{
|
|
// Whether the Foo kernel should truncate its output
|
|
bool truncateOutput;
|
|
|
|
// Whether a type label is used in the schema, versus a type description
|
|
bool useTypeLabel;
|
|
|
|
// Whether the schema provides a type inference function, and uses an output type
|
|
// of Int32 instead of Float32
|
|
bool useTypeInference;
|
|
|
|
// Whether a shape inference method is provided in the schema
|
|
bool useShapeInferenceInSchema;
|
|
|
|
// Whether a shape inference method is provided in the kernel
|
|
bool useShapeInferenceInKernel;
|
|
|
|
// Whether attribute defaults are provided in the schema, instead of the kernel
|
|
bool attributeDefaultsInSchema;
|
|
} testCases[] =
|
|
{
|
|
{false, true, false, false, false, false},
|
|
{false, false, false, false, false, false},
|
|
{false, true, true, false, false, true},
|
|
{true, false, false, false, true, false},
|
|
{true, true, true, true, true, true},
|
|
};
|
|
|
|
for (size_t caseIndex = 0; caseIndex < std::size(testCases); ++caseIndex) {
|
|
// Create the registry
|
|
auto operatorProvider = winrt::make<LocalCustomOperatorProvider>();
|
|
IMLOperatorRegistry* registry = operatorProvider.as<LocalCustomOperatorProvider>()->GetRegistry();
|
|
|
|
// Create input and output parameters
|
|
MLOperatorSchemaEdgeDescription inputParam = {};
|
|
inputParam.options = MLOperatorParameterOptions::Single;
|
|
|
|
if (!testCases[caseIndex].useTypeLabel) {
|
|
assert(!testCases[caseIndex].useTypeInference);
|
|
|
|
MLOperatorEdgeDescription edgeDesc = {};
|
|
edgeDesc.edgeType = MLOperatorEdgeType::Tensor;
|
|
edgeDesc.tensorDataType = MLOperatorTensorDataType::Float;
|
|
|
|
inputParam.typeFormat = MLOperatorSchemaEdgeTypeFormat::EdgeDescription;
|
|
inputParam.edgeDescription = edgeDesc;
|
|
} else {
|
|
inputParam.typeFormat = MLOperatorSchemaEdgeTypeFormat::Label;
|
|
inputParam.typeLabel = "T1";
|
|
}
|
|
|
|
MLOperatorSchemaEdgeDescription outputParam = inputParam;
|
|
|
|
// Type inference should set this to tensor(float) even though T2 is not matched
|
|
// on an input label
|
|
if (testCases[caseIndex].useTypeInference) {
|
|
if (inputParam.typeFormat == MLOperatorSchemaEdgeTypeFormat::Label) {
|
|
outputParam.typeLabel = "T2";
|
|
} else {
|
|
outputParam.edgeDescription.tensorDataType = MLOperatorTensorDataType::Int32;
|
|
}
|
|
}
|
|
|
|
MLOperatorSchemaEdgeDescription inputs[] = {inputParam, inputParam};
|
|
|
|
MLOperatorEdgeDescription edgeTypes[6] =
|
|
{
|
|
{MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::UInt32)},
|
|
{MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::UInt64)},
|
|
{MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Int32)},
|
|
{MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Int64)},
|
|
{MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Float)},
|
|
{MLOperatorEdgeType::Tensor, static_cast<uint64_t>(MLOperatorTensorDataType::Double)}};
|
|
|
|
// Type constraints. Only the first is used unless type inference is provided and
|
|
// the kernel emits a different output type as "T2"
|
|
MLOperatorEdgeTypeConstrant constraints[] =
|
|
{
|
|
{"T1", edgeTypes, static_cast<uint32_t>(std::size(edgeTypes))},
|
|
{"T2", edgeTypes, static_cast<uint32_t>(std::size(edgeTypes))}};
|
|
|
|
// Test attributes
|
|
MLOperatorAttribute attributes[] =
|
|
{
|
|
{"DefaultedNonRequiredInt", MLOperatorAttributeType::Int, false},
|
|
{"DefaultedNonRequiredFloat", MLOperatorAttributeType::Float, false},
|
|
{"DefaultedNonRequiredString", MLOperatorAttributeType::String, false},
|
|
{"DefaultedNonRequiredIntArray", MLOperatorAttributeType::IntArray, false},
|
|
{"DefaultedNonRequiredFloatArray", MLOperatorAttributeType::FloatArray, false},
|
|
{"DefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, false},
|
|
|
|
{"NonDefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, false},
|
|
};
|
|
|
|
// Defaults. These are queried back during kernel creation, type and shape inference
|
|
// and tested against the same values
|
|
MLOperatorAttributeNameValue defaultAttributes[] =
|
|
{
|
|
{"DefaultedNonRequiredInt", MLOperatorAttributeType::Int, 1},
|
|
{"DefaultedNonRequiredFloat", MLOperatorAttributeType::Float, 1},
|
|
{"DefaultedNonRequiredString", MLOperatorAttributeType::String, 1},
|
|
{"DefaultedNonRequiredIntArray", MLOperatorAttributeType::IntArray, 2},
|
|
{"DefaultedNonRequiredFloatArray", MLOperatorAttributeType::FloatArray, 2},
|
|
{"DefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, 2},
|
|
};
|
|
|
|
int64_t defaultInts[] = {1, 2};
|
|
float defaultFloats[] = {1.0f, 2.0f};
|
|
const char* defaultStrings[] = {"1", "2"};
|
|
defaultAttributes[0].ints = defaultInts;
|
|
defaultAttributes[1].floats = defaultFloats;
|
|
defaultAttributes[2].strings = defaultStrings;
|
|
defaultAttributes[3].ints = defaultInts;
|
|
defaultAttributes[4].floats = defaultFloats;
|
|
defaultAttributes[5].strings = defaultStrings;
|
|
|
|
// Schema definition
|
|
MLOperatorSchemaDescription schemaDesc = {};
|
|
schemaDesc.name = "Foo";
|
|
schemaDesc.operatorSetVersionAtLastChange = 7;
|
|
schemaDesc.inputs = inputs;
|
|
schemaDesc.inputCount = 2;
|
|
schemaDesc.outputs = &outputParam;
|
|
schemaDesc.outputCount = 1;
|
|
schemaDesc.typeConstraints = constraints;
|
|
schemaDesc.typeConstraintCount = testCases[caseIndex].useTypeLabel ? 2 : 0;
|
|
schemaDesc.attributes = attributes;
|
|
schemaDesc.attributeCount = static_cast<uint32_t>(std::size(attributes));
|
|
|
|
if (testCases[caseIndex].attributeDefaultsInSchema) {
|
|
schemaDesc.defaultAttributes = defaultAttributes;
|
|
schemaDesc.defaultAttributeCount = static_cast<uint32_t>(std::size(defaultAttributes));
|
|
}
|
|
|
|
Microsoft::WRL::ComPtr<MLOperatorTypeInferrer> typeInferrer;
|
|
Microsoft::WRL::ComPtr<MLOperatorShapeInferrerFromFunc> shapeInferrer;
|
|
|
|
// Type inference function
|
|
if (testCases[caseIndex].useTypeInference) {
|
|
typeInferrer = wil::MakeOrThrow<MLOperatorTypeInferrer>([](IMLOperatorTypeInferenceContext* ctx) -> void {
|
|
VerifyTestAttributes(MLOperatorTypeInferenceContext(ctx));
|
|
|
|
MLOperatorEdgeDescription edgeDesc = {};
|
|
edgeDesc.edgeType = MLOperatorEdgeType::Tensor;
|
|
edgeDesc.tensorDataType = MLOperatorTensorDataType::Float;
|
|
|
|
MLOperatorTypeInferenceContext(ctx).SetOutputEdgeDescription(0, &edgeDesc);
|
|
});
|
|
}
|
|
|
|
// Store the shape inference context with a reference following the call to InferOutputShapes.
|
|
// This will be called after loading the model as an isolated test for how ABI context objects
|
|
// are "closed."
|
|
Microsoft::WRL::ComPtr<IMLOperatorShapeInferenceContext> shapeInferenceContext;
|
|
|
|
// Shape inference is tested by truncating the output size
|
|
bool truncateOutput = testCases[caseIndex].truncateOutput;
|
|
if (truncateOutput) {
|
|
shapeInferrer = wil::MakeOrThrow<MLOperatorShapeInferrerFromFunc>([&shapeInferenceContext](IMLOperatorShapeInferenceContext* ctx) -> void {
|
|
VerifyTestAttributes(MLShapeInferenceContext(ctx));
|
|
MLShapeInferenceContext(ctx).SetOutputTensorShape(0, {2, 2});
|
|
shapeInferenceContext = ctx;
|
|
});
|
|
}
|
|
|
|
// Register the schema
|
|
MLOperatorSetId opsetId = {"", 7};
|
|
MLOperatorSchemaDescription* opSchemaDescs = &schemaDesc;
|
|
WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorSetSchema(
|
|
&opsetId,
|
|
1,
|
|
&opSchemaDescs,
|
|
1,
|
|
typeInferrer.Get(),
|
|
testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr));
|
|
|
|
{
|
|
// Register a future version of the schema in the same domain, while setting its
|
|
// input count to zero to ensure it is not being used.
|
|
auto futureSchemaDesc = schemaDesc;
|
|
futureSchemaDesc.inputCount = 0;
|
|
|
|
MLOperatorSetId id = {"", 9};
|
|
MLOperatorSchemaDescription* schemaDescs = &futureSchemaDesc;
|
|
WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorSetSchema(
|
|
&id,
|
|
7,
|
|
&schemaDescs,
|
|
1,
|
|
typeInferrer.Get(),
|
|
testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr));
|
|
}
|
|
{
|
|
// Register in another (unused) domain to the custom registry
|
|
auto otherSchemaDesc = schemaDesc;
|
|
otherSchemaDesc.inputCount = 0;
|
|
|
|
MLOperatorSetId id = {"otherDomain", 7};
|
|
MLOperatorSchemaDescription* schemaDescs = &otherSchemaDesc;
|
|
WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorSetSchema(
|
|
&id,
|
|
1,
|
|
&schemaDescs,
|
|
1,
|
|
typeInferrer.Get(),
|
|
testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr));
|
|
}
|
|
// Register the Foo kernel
|
|
MLOperatorEdgeDescription floatTensorEdgeDesc = {};
|
|
floatTensorEdgeDesc.edgeType = MLOperatorEdgeType::Tensor;
|
|
floatTensorEdgeDesc.tensorDataType = MLOperatorTensorDataType::Float;
|
|
|
|
MLOperatorEdgeTypeConstrant kernelConstraint = {"T", &floatTensorEdgeDesc, 1};
|
|
|
|
MLOperatorKernelDescription kernelDesc =
|
|
{
|
|
"",
|
|
"Foo",
|
|
7,
|
|
MLOperatorExecutionType::Cpu,
|
|
&kernelConstraint,
|
|
1};
|
|
|
|
if (!testCases[caseIndex].attributeDefaultsInSchema) {
|
|
kernelDesc.defaultAttributes = defaultAttributes;
|
|
kernelDesc.defaultAttributeCount = static_cast<uint32_t>(std::size(defaultAttributes));
|
|
}
|
|
|
|
if (!truncateOutput) {
|
|
kernelDesc.options = MLOperatorKernelOptions::AllowDynamicInputShapes;
|
|
Microsoft::WRL::ComPtr<MLOperatorKernelFactory> factory = wil::MakeOrThrow<MLOperatorKernelFactory>(CreateABIFooKernel<true>);
|
|
|
|
WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorKernel(&kernelDesc, factory.Get(), nullptr));
|
|
} else {
|
|
Microsoft::WRL::ComPtr<MLOperatorKernelFactory> factory = wil::MakeOrThrow<MLOperatorKernelFactory>(CreateTruncatedABIFooKernel);
|
|
WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorKernel(
|
|
&kernelDesc,
|
|
factory.Get(),
|
|
testCases[caseIndex].useShapeInferenceInKernel ? shapeInferrer.Get() : nullptr));
|
|
}
|
|
|
|
// Prepare inputs
|
|
std::vector<int64_t> dimsX = {3, 2};
|
|
std::vector<float> valuesX = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
|
|
|
// Prepare expected inputs and outputs
|
|
std::vector<int64_t> expectedDimsY = {truncateOutput ? 2 : 3, 2};
|
|
// now the expected value should be Add's result.
|
|
std::vector<float> expectedValuesY = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};
|
|
if (truncateOutput) {
|
|
// The leading dimension is truncated, and the second dimension has two elements over that dim
|
|
expectedValuesY.resize(expectedValuesY.size() - 2);
|
|
}
|
|
|
|
// Load the model and sessions
|
|
std::wstring fullPath = FileHelpers::GetModulePath() + (truncateOutput ? L"foo_truncated.onnx" : L"foo.onnx");
|
|
LearningModel model = LearningModel::LoadFromFilePath(fullPath, operatorProvider);
|
|
LearningModelSession session(model);
|
|
|
|
// Bind input and outputs
|
|
LearningModelBinding bindings(session);
|
|
|
|
TensorFloat inputTensor = TensorFloat::CreateFromArray(dimsX, winrt::array_view<const float>(std::move(valuesX)));
|
|
bindings.Bind(winrt::hstring(L"X"), inputTensor);
|
|
|
|
auto outputValue = TensorFloat::Create();
|
|
WINML_EXPECT_NO_THROW(bindings.Bind(L"Y", outputValue));
|
|
|
|
// Evaluate the model
|
|
winrt::hstring correlationId;
|
|
WINML_EXPECT_NO_THROW(session.Evaluate(bindings, correlationId));
|
|
|
|
// Verify the result shape
|
|
WINML_EXPECT_EQUAL(expectedDimsY.size(), outputValue.Shape().Size());
|
|
for (uint32_t j = 0; j < outputValue.Shape().Size(); j++) {
|
|
WINML_EXPECT_EQUAL(expectedDimsY.at(j), outputValue.Shape().GetAt(j));
|
|
}
|
|
|
|
// Verify the result values
|
|
auto buffer = outputValue.GetAsVectorView();
|
|
WINML_EXPECT_TRUE(buffer != nullptr);
|
|
WINML_EXPECT_TRUE(std::equal(expectedValuesY.cbegin(), expectedValuesY.cend(), begin(buffer)));
|
|
|
|
// Release the model before operatorProvider goes out of scope
|
|
model = nullptr;
|
|
|
|
if (shapeInferenceContext) {
|
|
// Check that the shape inference context is closed and safely fails
|
|
MLOperatorEdgeDescription edgeDesc;
|
|
WINML_EXPECT_EQUAL(E_INVALIDARG, shapeInferenceContext->GetInputEdgeDescription(0, &edgeDesc));
|
|
}
|
|
}
|
|
}
|
|
|
|
const CustomOpsTestsApi& getapi() {
|
|
static CustomOpsTestsApi api =
|
|
{
|
|
CustomOpsScenarioTestsClassSetup,
|
|
CustomOperatorFusion,
|
|
CustomKernelWithBuiltInSchema,
|
|
CustomKernelWithCustomSchema
|
|
};
|
|
|
|
if (SkipGpuTests()) {
|
|
api.CustomOperatorFusion = SkipTest;
|
|
}
|
|
if (RuntimeParameterExists(L"noVideoFrameTests")) {
|
|
api.CustomOperatorFusion = SkipTest;
|
|
}
|
|
return api;
|
|
} |