mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
* comments for dml graph transformer * fixed OrtValue passing using the allocator info * fixed and coded maps and sequences across the ABI * cleaned up W4 warnings * cleaned up the model info ABI * delay-load directml.dll from winml
104 lines
3.6 KiB
C++
104 lines
3.6 KiB
C++
//
|
|
// Implements a custom operator kernel which counts the number of calls to Compute(), but otherwise is a no-op.
|
|
//
|
|
|
|
#pragma once
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
template <typename T>
|
|
struct NullShapeInferrer : winrt::implements<NullShapeInferrer<T>, IMLOperatorShapeInferrer>
|
|
{
|
|
STDMETHOD(InferOutputShapes)(IMLOperatorShapeInferenceContext* context) noexcept
|
|
{
|
|
EXPECT_NO_THROW(OperatorHelper::ShapeInferenceFunction<T>(context));
|
|
return S_OK;
|
|
}
|
|
};
|
|
|
|
// Custom operator kernel used by tests. Compute() requests output tensor 0,
// expecting the request to succeed, and then increments a shared call counter.
// It writes no output data.
struct NullOperator : winrt::implements<NullOperator, IMLOperatorKernel>
{
    // callCount: counter incremented once per Compute() call. It is stored as a
    // raw, non-owning pointer, so the caller must keep the atomic alive for as
    // long as this kernel can be computed.
    explicit NullOperator(std::atomic<uint32_t>* callCount) : m_callCount(callCount) {}

    // noexcept: this method is called through the ML operator ABI and must not
    // let exceptions escape. This matches the noexcept on the interface
    // declaration (NOTE(review): confirm against the ML operator author header).
    STDMETHOD(Compute)(IMLOperatorKernelContext* context) noexcept override
    {
        // Fetch output 0 so the test checks that the context can provide it.
        // The tensor contents are intentionally left untouched.
        winrt::com_ptr<IMLOperatorTensor> outputTensor;
        EXPECT_HRESULT_SUCCEEDED(context->GetOutputTensor(0, outputTensor.put()));

        ++(*m_callCount);
        return S_OK;
    }

private:
    // Non-owning; see constructor.
    std::atomic<uint32_t>* m_callCount;
};
|
|
|
|
// Kernel factory that creates NullOperator instances sharing one call counter.
// It also provides a static helper that registers such a kernel with an
// operator registry.
struct NullOperatorFactory : winrt::implements<NullOperatorFactory, IMLOperatorKernelFactory>
{
    // callCount: counter handed to every NullOperator this factory creates.
    // It is stored as a raw, non-owning pointer.
    explicit NullOperatorFactory(std::atomic<uint32_t>* callCount) : m_callCount(callCount) {}

    // Creates a NullOperator and returns it through `kernel` with one reference.
    // Returns E_POINTER if `kernel` is null. Any exception (for example an
    // allocation failure inside winrt::make) is converted to an HRESULT instead
    // of being allowed to cross the COM ABI boundary. On failure, `*kernel` is
    // left null.
    STDMETHOD(CreateKernel)(
        IMLOperatorKernelCreationContext* context,
        IMLOperatorKernel** kernel) noexcept override
    {
        ORT_UNUSED_PARAMETER(context);

        if (kernel == nullptr)
        {
            return E_POINTER;
        }
        *kernel = nullptr;

        try
        {
            auto op = winrt::make<NullOperator>(m_callCount);
            op.copy_to(kernel);
            return S_OK;
        }
        catch (...)
        {
            return winrt::to_hresult();
        }
    }

    // Builds an edge description with the given edge type and tensor data type.
    // The description is value-initialized first, so no member storage that is
    // not assigned below is left indeterminate.
    static MLOperatorEdgeDescription CreateEdgeDescriptor(MLOperatorEdgeType type, MLOperatorTensorDataType dataType)
    {
        MLOperatorEdgeDescription desc = {};
        desc.edgeType = type;
        desc.tensorDataType = dataType;
        return desc;
    }

    // Registers a NullOperator kernel with `registry` using these settings:
    //   - operator `name` in `domain`, starting at opset `versionSince`;
    //   - type label "T" bound to Double, Float or Float16 tensors;
    //   - MLOperatorExecutionType::D3D12;
    //   - no default attributes and no options.
    // `shapeInferrer` is passed through to the registry unchanged. Each kernel
    // created by the registered factory increments `*callCount` once per
    // Compute() call.
    // A failing registration is reported as a non-fatal gtest failure.
    static void RegisterKernel(
        const char* name,
        const char* domain,
        int versionSince,
        winrt::com_ptr<IMLOperatorRegistry> registry,
        winrt::com_ptr<IMLOperatorShapeInferrer> shapeInferrer,
        std::atomic<uint32_t>* callCount)
    {
        // The descriptions below point into these vectors, so both vectors must
        // stay alive until RegisterOperatorKernel returns.
        std::vector<MLOperatorEdgeDescription> allowedEdges
        {
            CreateEdgeDescriptor(MLOperatorEdgeType::Tensor, MLOperatorTensorDataType::Double),
            CreateEdgeDescriptor(MLOperatorEdgeType::Tensor, MLOperatorTensorDataType::Float),
            CreateEdgeDescriptor(MLOperatorEdgeType::Tensor, MLOperatorTensorDataType::Float16)
        };

        MLOperatorEdgeTypeConstrant typeConstraint = {};
        typeConstraint.typeLabel = "T";
        typeConstraint.allowedTypes = allowedEdges.data();
        typeConstraint.allowedTypeCount = static_cast<uint32_t>(allowedEdges.size());

        std::vector<MLOperatorEdgeTypeConstrant> typeConstraints{ typeConstraint };

        // Value-initialize so any field not explicitly set below is zeroed.
        MLOperatorKernelDescription kernelDescription = {};
        kernelDescription.domain = domain;
        kernelDescription.name = name;
        kernelDescription.minimumOperatorSetVersion = versionSince;
        kernelDescription.executionType = MLOperatorExecutionType::D3D12;
        kernelDescription.typeConstraints = typeConstraints.data();
        kernelDescription.typeConstraintCount = static_cast<uint32_t>(typeConstraints.size());
        kernelDescription.defaultAttributes = nullptr;
        kernelDescription.defaultAttributeCount = 0;
        kernelDescription.options = MLOperatorKernelOptions::None;
        kernelDescription.executionOptions = 0;

        auto factory = winrt::make<NullOperatorFactory>(callCount);

        EXPECT_HRESULT_SUCCEEDED(registry->RegisterOperatorKernel(
            &kernelDescription,
            factory.get(),
            shapeInferrer.get()
        ));
    }

private:
    // Non-owning; forwarded to every NullOperator created by CreateKernel.
    std::atomic<uint32_t>* m_callCount;
};
|