// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. // // Implements a custom operator kernel which counts the number of calls to Compute(), but otherwise is a no-op. // #pragma once #include "test.h" template struct NullShapeInferrer : winrt::implements, IMLOperatorShapeInferrer> { STDMETHOD(InferOutputShapes)(IMLOperatorShapeInferenceContext* context) noexcept { WINML_EXPECT_NO_THROW(OperatorHelper::ShapeInferenceFunction(context)); return S_OK; } }; struct NullOperator : winrt::implements { NullOperator(std::atomic* callCount) : m_callCount(callCount) {} STDMETHOD(Compute)(IMLOperatorKernelContext* context) { winrt::com_ptr outputTensor; WINML_EXPECT_HRESULT_SUCCEEDED(context->GetOutputTensor(0, outputTensor.put())); ++(*m_callCount); return S_OK; } private: std::atomic* m_callCount; }; struct NullOperatorFactory : winrt::implements { NullOperatorFactory(std::atomic* callCount) : m_callCount(callCount) {} STDMETHOD(CreateKernel)(IMLOperatorKernelCreationContext* context, IMLOperatorKernel** kernel) { ORT_UNUSED_PARAMETER(context); auto op = winrt::make(m_callCount); op.copy_to(kernel); return S_OK; } static MLOperatorEdgeDescription CreateEdgeDescriptor(MLOperatorEdgeType type, MLOperatorTensorDataType dataType) { ORT_UNUSED_PARAMETER(type); MLOperatorEdgeDescription desc; desc.edgeType = MLOperatorEdgeType::Tensor; desc.tensorDataType = dataType; return desc; } static void RegisterKernel( const char* name, const char* domain, int versionSince, winrt::com_ptr registry, winrt::com_ptr shapeInferrer, std::atomic* callCount ) { MLOperatorKernelDescription kernelDescription; kernelDescription.domain = domain; kernelDescription.name = name; kernelDescription.minimumOperatorSetVersion = versionSince; kernelDescription.executionType = MLOperatorExecutionType::D3D12; MLOperatorEdgeTypeConstrant typeConstraint; typeConstraint.typeLabel = "T"; std::vector allowedEdges{ CreateEdgeDescriptor(MLOperatorEdgeType::Tensor, MLOperatorTensorDataType::Double), CreateEdgeDescriptor(MLOperatorEdgeType::Tensor, MLOperatorTensorDataType::Float), CreateEdgeDescriptor(MLOperatorEdgeType::Tensor, MLOperatorTensorDataType::Float16) }; typeConstraint.allowedTypes = allowedEdges.data(); typeConstraint.allowedTypeCount = static_cast(allowedEdges.size()); std::vector typeConstraints{typeConstraint}; kernelDescription.typeConstraints = typeConstraints.data(); kernelDescription.typeConstraintCount = static_cast(typeConstraints.size()); kernelDescription.defaultAttributes = nullptr; kernelDescription.defaultAttributeCount = 0; kernelDescription.options = MLOperatorKernelOptions::None; kernelDescription.executionOptions = 0; auto factory = winrt::make(callCount); WINML_EXPECT_HRESULT_SUCCEEDED( registry->RegisterOperatorKernel(&kernelDescription, factory.get(), shapeInferrer.get()) ); } private: std::atomic* m_callCount; };