onnxruntime/winml/adapter/WinMLAdapter.h

148 lines
No EOL
7.8 KiB
C++

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include "core/session/onnxruntime_c_api.h"
namespace Windows::AI::MachineLearning::Adapter {
// Forward-declares the TraceLogging provider handle used by the WinML adapter.
// The handle itself must be defined elsewhere with TRACELOGGING_DEFINE_PROVIDER.
TRACELOGGING_DECLARE_PROVIDER(winml_trace_logging_provider);
// IModelInfo: COM interface exposing a model's metadata and its input/output
// feature descriptors.
// NOTE: COM interfaces are consumed through their vtable, whose layout follows
// declaration order — do not reorder or insert methods.
// NOTE(review): lifetime of the `const char*` strings returned below is not
// visible here; presumably they are owned by the implementing object — confirm
// before holding them past that object's lifetime.
MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") IModelInfo : IUnknown{
// model metadata
virtual const char* STDMETHODCALLTYPE author() = 0;
virtual const char* STDMETHODCALLTYPE name() = 0;
virtual const char* STDMETHODCALLTYPE domain() = 0;
virtual const char* STDMETHODCALLTYPE description() = 0;
virtual int64_t STDMETHODCALLTYPE version() = 0;
// The getters below report failure through the HRESULT and return their
// collection through the out-parameter.
// Model metadata as a string-to-string map view.
virtual HRESULT STDMETHODCALLTYPE GetModelMetadata(ABI::Windows::Foundation::Collections::IMapView<HSTRING, HSTRING> ** metadata) = 0;
// Descriptors of the model's input features.
virtual HRESULT STDMETHODCALLTYPE GetInputFeatures(ABI::Windows::Foundation::Collections::IVectorView<winml::ILearningModelFeatureDescriptor> * *features) = 0;
// Descriptors of the model's output features.
virtual HRESULT STDMETHODCALLTYPE GetOutputFeatures(ABI::Windows::Foundation::Collections::IVectorView<winml::ILearningModelFeatureDescriptor> * *features) = 0;
};
// IModelProto: COM wrapper around an onnx::ModelProto.
// NOTE: vtable layout follows declaration order — do not reorder methods.
MIDL_INTERFACE("a848faf6-5a2e-4a7f-b622-cc036f71e28a") IModelProto : IUnknown{
// Returns a weak (non-owning) pointer: the wrapper keeps ownership, so the
// caller must not delete it.
virtual onnx::ModelProto* STDMETHODCALLTYPE get() = 0;
// Hands ownership of the ModelProto to the caller, leaving the COM reference
// count untouched, and makes the wrapper forget about the object. The caller
// becomes responsible for freeing it.
// NOTE(review): presumably get() yields nullptr after detach() — confirm in
// the implementation.
virtual onnx::ModelProto* STDMETHODCALLTYPE detach() = 0;
};
// IInferenceSession: COM interface over an onnxruntime::InferenceSession.
// NOTE: vtable layout follows declaration order — do not reorder methods.
MIDL_INTERFACE("6ec766ef-6365-42bf-b64f-ae85c015adb8") IInferenceSession : IUnknown {
// Returns the underlying session as a raw, non-owning pointer.
virtual onnxruntime::InferenceSession* STDMETHODCALLTYPE get() = 0;
// Returns the same session as an ORT C API handle. The result is a weak
// reference: DO NOT release or delete it.
virtual HRESULT STDMETHODCALLTYPE GetOrtSession(OrtSession ** out) = 0;
// Registers graph transformers on the session; returns void, so failures are
// not reported through this call.
virtual void STDMETHODCALLTYPE RegisterGraphTransformers() = 0;
// Registers a custom operator registry with the session.
virtual HRESULT STDMETHODCALLTYPE RegisterCustomRegistry(IMLOperatorRegistry * registry) = 0;
// Loads the given model into the session.
// NOTE(review): whether this copies or detach()es the proto is not visible
// here — confirm before reusing model_proto afterwards.
virtual HRESULT STDMETHODCALLTYPE LoadModel(IModelProto* model_proto) = 0;
// Start / stop ORT profiling for this session.
virtual HRESULT STDMETHODCALLTYPE StartProfiling() = 0;
virtual HRESULT STDMETHODCALLTYPE EndProfiling() = 0;
// DirectML execution-provider maintenance hooks; each takes the DML provider.
// NOTE(review): exact semantics (command-list flush, upload-heap trimming,
// releasing finished GPU work references) inferred from names — confirm.
virtual void STDMETHODCALLTYPE FlushContext(onnxruntime::IExecutionProvider * dml_provider) = 0;
virtual void STDMETHODCALLTYPE TrimUploadHeap(onnxruntime::IExecutionProvider* dml_provider) = 0;
virtual void STDMETHODCALLTYPE ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) = 0;
// Produces in *new_mlvalue a copy of orig_mlvalue for the input named
// input_name. NOTE(review): presumably placed on the device that input is
// bound to — confirm in the implementation.
virtual HRESULT STDMETHODCALLTYPE CopyOneInputAcrossDevices(const char* input_name,
const OrtValue* orig_mlvalue, OrtValue** new_mlvalue) = 0;
};
// IOrtSessionBuilder abstracts the creation of an onnxruntime InferenceSession
// so that the session can be configured for a specific device (CPU or DirectML).
// NOTE(review): the expected call order is presumably CreateSessionOptions ->
// CreateSession -> Initialize; confirm against the implementations.
// NOTE: vtable layout follows declaration order — do not reorder methods.
MIDL_INTERFACE("2746f03a-7e08-4564-b5d0-c670fef116ee") IOrtSessionBuilder : IUnknown {
// Creates device-appropriate session options.
virtual HRESULT STDMETHODCALLTYPE CreateSessionOptions(
OrtSessionOptions ** options) = 0;
// Creates a session from `options` and also returns the execution provider
// it was configured with.
// NOTE(review): ownership of the returned raw IExecutionProvider* is not
// visible here — confirm before storing it.
virtual HRESULT STDMETHODCALLTYPE CreateSession(
OrtSessionOptions * options,
IInferenceSession** session,
onnxruntime::IExecutionProvider** provider) = 0;
// Finishes setting up a session previously produced by CreateSession,
// using the provider CreateSession returned.
virtual HRESULT STDMETHODCALLTYPE Initialize(
IInferenceSession* session,
onnxruntime::IExecutionProvider* provider) = 0;
};
// IWinMLAdapter: top-level COM interface through which WinML reaches into
// onnxruntime: model loading, session-builder creation, custom operator
// registries, DirectML execution-provider hooks and OrtValue inspection.
// An instance is obtained through the exported OrtGetWinMLAdapter function.
// NOTE: vtable layout follows declaration order (including the
// CreateModelProto overloads) — do not reorder or insert methods.
MIDL_INTERFACE("b19385e7-d9af-441a-ba7f-3993c7b1c9db") IWinMLAdapter : IUnknown {
virtual void STDMETHODCALLTYPE EnableDebugOutput() = 0;
// Checks `p_model_proto` against the target device; is_float16_supported
// tells it whether the device supports float16.
// NOTE(review): whether the proto may be modified in place is not visible
// here — confirm in the implementation.
virtual HRESULT STDMETHODCALLTYPE EnsureModelDeviceCompatibility(
winml::LearningModel const& model,
IModelProto* p_model_proto,
bool is_float16_supported) = 0;
// Factory method: creates an IOrtSessionBuilder for the given D3D12 device
// and command queue.
// NOTE(review): presumably null device/queue selects the CPU path — confirm.
virtual HRESULT STDMETHODCALLTYPE CreateOrtSessionBuilder(
ID3D12Device* device,
ID3D12CommandQueue* queue,
IOrtSessionBuilder** session_builder) = 0;
// Factory methods for model protos: from a file path, from a WinRT stream
// reference, or from an existing IModelProto (presumably a copy — confirm).
virtual HRESULT STDMETHODCALLTYPE CreateModelProto(const char* path, IModelProto** model_proto) = 0;
virtual HRESULT STDMETHODCALLTYPE CreateModelProto(ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream_reference, IModelProto** model_proto) = 0;
virtual HRESULT STDMETHODCALLTYPE CreateModelProto(IModelProto * model_proto_in, IModelProto** model_proto) = 0;
// Builds an IModelInfo (metadata and feature descriptors) for a model proto.
virtual HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto * model_proto, IModelInfo ** model_info) = 0;
// Data types: no methods are currently declared under this heading
// (map/sequence type queries appear with the OrtValue helpers below).
// Custom operators: registry access.
virtual HRESULT STDMETHODCALLTYPE GetCustomRegistry(IMLOperatorRegistry** registry) = 0;
// Obtains an operator registry from a native LearningModel operator provider.
virtual HRESULT STDMETHODCALLTYPE GetOperatorRegistry(ILearningModelOperatorProviderNative * operator_provider_native, IMLOperatorRegistry * *registry) = 0;
// DirectML execution-provider hooks.
// NOTE(review): allocations from CreateGPUAllocationFromD3DResource are
// presumably released with FreeGPUAllocation — confirm the pairing.
virtual void* STDMETHODCALLTYPE CreateGPUAllocationFromD3DResource(ID3D12Resource* pResource) = 0;
virtual void STDMETHODCALLTYPE FreeGPUAllocation(void* ptr) = 0;
// Copies the contents of `src` into `dst` using `provider`.
virtual HRESULT STDMETHODCALLTYPE CopyTensor(onnxruntime::IExecutionProvider* provider, OrtValue* src, OrtValue* dst) = 0;
// Returns a weak reference: do NOT Release the returned ID3D12Resource.
virtual ID3D12Resource* STDMETHODCALLTYPE GetD3D12ResourceFromAllocation(onnxruntime::IExecutionProvider * provider, void* allocation) = 0;
// Schema overrides (the DML execution provider relies on these).
virtual HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() = 0;
// Proposed adapter surface expressed in ORT C API types (OrtMemoryInfo,
// OrtAllocator, OrtValue, ONNXTensorElementDataType) rather than internal
// C++ types.
virtual HRESULT STDMETHODCALLTYPE GetProviderMemoryInfo(onnxruntime::IExecutionProvider * provider, OrtMemoryInfo** memory_info) = 0;
virtual HRESULT STDMETHODCALLTYPE GetProviderAllocator(onnxruntime::IExecutionProvider * provider, OrtAllocator** allocator) = 0;
virtual HRESULT STDMETHODCALLTYPE GetValueMemoryInfo(const OrtValue * value, OrtMemoryInfo** memory_info) = 0;
// Report the key and value element types of a map-typed OrtValue, and
// (presumably) of a sequence-of-maps OrtValue — confirm for GetVectorMapType.
virtual HRESULT STDMETHODCALLTYPE GetMapType(const OrtValue * ort_value, ONNXTensorElementDataType * key_type, ONNXTensorElementDataType * value_type) = 0;
virtual HRESULT STDMETHODCALLTYPE GetVectorMapType(const OrtValue * ort_value, ONNXTensorElementDataType * key_type, ONNXTensorElementDataType * value_type) = 0;
// NOTE(review): disabled declarations kept for reference; re-enabling them
// changes the vtable layout.
//virtual HRESULT STDMETHODCALLTYPE CreateTensorFromMap(IInspectable * map, OrtValue * *ort_value) = 0;
//virtual HRESULT STDMETHODCALLTYPE CreateTensorFromSequence(IInspectable * sequence, OrtValue * *ort_value) = 0;
};
// Exported C-linkage entry point that hands out the IWinMLAdapter
// implementation through `adapter`.
// NOTE(review): presumably the returned interface follows COM convention
// (already AddRef'd, caller must Release) — confirm in the definition.
extern "C" __declspec(dllexport) HRESULT STDMETHODCALLTYPE
OrtGetWinMLAdapter(IWinMLAdapter** adapter);
// InferenceSession: classic-COM (WRL RuntimeClass) implementation of
// IInferenceSession that wraps an onnxruntime::InferenceSession held in a
// std::shared_ptr.
// NOTE(review): the shared_ptr member suggests the constructor takes ownership
// of the raw pointer it receives — confirm in the constructor's definition,
// which is not in this header.
class InferenceSession : public Microsoft::WRL::RuntimeClass<
                             Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
                             IInferenceSession> {
 public:
  // `explicit` prevents a raw onnxruntime::InferenceSession* from being
  // implicitly converted into this COM object.
  explicit InferenceSession(onnxruntime::InferenceSession* session);

  // Returns the wrapped session as a non-owning pointer; it stays valid only
  // while this object keeps the session alive.
  onnxruntime::InferenceSession* STDMETHODCALLTYPE get() noexcept override {
    return session_.get();
  }

  // Writes the wrapped session to *out as an ORT C API handle.
  // The result is a weak reference: the caller must not release or delete it.
  // Returns E_POINTER when `out` is null (standard COM out-parameter check),
  // S_OK otherwise.
  HRESULT STDMETHODCALLTYPE GetOrtSession(OrtSession** out) noexcept override {
    if (out == nullptr) {
      return E_POINTER;
    }
    // OrtSession is the C API's opaque handle type; an (OrtSession*) is really
    // an (onnxruntime::InferenceSession*), so the pointer is reinterpreted.
    *out = reinterpret_cast<OrtSession*>(session_.get());
    return S_OK;
  }

  // The remaining IInferenceSession methods are declared here and defined
  // outside this header.
  void STDMETHODCALLTYPE RegisterGraphTransformers() override;
  HRESULT STDMETHODCALLTYPE RegisterCustomRegistry(IMLOperatorRegistry* registry) override;
  HRESULT STDMETHODCALLTYPE LoadModel(IModelProto* model_proto) override;
  HRESULT STDMETHODCALLTYPE StartProfiling() override;
  HRESULT STDMETHODCALLTYPE EndProfiling() override;
  void STDMETHODCALLTYPE FlushContext(onnxruntime::IExecutionProvider* dml_provider) override;
  void STDMETHODCALLTYPE TrimUploadHeap(onnxruntime::IExecutionProvider* dml_provider) override;
  void STDMETHODCALLTYPE ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) override;
  HRESULT STDMETHODCALLTYPE CopyOneInputAcrossDevices(const char* input_name,
                                                      const OrtValue* orig_mlvalue,
                                                      OrtValue** new_mlvalue) override;

 private:
  // The wrapped session; shared so its lifetime can outlive raw get() callers
  // that also hold the shared_ptr elsewhere.
  std::shared_ptr<onnxruntime::InferenceSession> session_;
};
} // namespace Windows::AI::MachineLearning::Adapter