mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
148 lines
No EOL
7.8 KiB
C++
148 lines
No EOL
7.8 KiB
C++
// Copyright (c) Microsoft Corporation.
|
|
// Licensed under the MIT License.
|
|
|
|
#pragma once
|
|
|
|
#include "core/session/onnxruntime_c_api.h"
|
|
|
|
namespace Windows::AI::MachineLearning::Adapter {
|
|
// Declares (does not define) the TraceLogging provider handle
// `winml_trace_logging_provider` for this adapter namespace.
// NOTE(review): the matching TRACELOGGING_DEFINE_PROVIDER must live in exactly
// one translation unit; it is not visible in this header.
TRACELOGGING_DECLARE_PROVIDER(winml_trace_logging_provider);
|
|
|
|
// Read-only view over a loaded model: its descriptive metadata and the
// descriptors of its input and output features.
//
// NOTE: the declaration order below fixes this interface's vtable layout;
// never reorder, insert, or remove members.
MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371")
IModelInfo : IUnknown {
  // --- model metadata ------------------------------------------------------
  // NOTE(review): ownership/lifetime of the returned C strings is not stated
  // here; presumably they are owned by the implementation and stay valid while
  // this IModelInfo is alive — confirm before caching them.
  virtual const char* STDMETHODCALLTYPE author() = 0;
  virtual const char* STDMETHODCALLTYPE name() = 0;
  virtual const char* STDMETHODCALLTYPE domain() = 0;
  virtual const char* STDMETHODCALLTYPE description() = 0;
  virtual int64_t STDMETHODCALLTYPE version() = 0;

  // Hands back the model's metadata as an HSTRING -> HSTRING map view.
  virtual HRESULT STDMETHODCALLTYPE GetModelMetadata(
      ABI::Windows::Foundation::Collections::IMapView<HSTRING, HSTRING>** out_metadata) = 0;

  // Hands back a vector view of the descriptors for the model's inputs.
  virtual HRESULT STDMETHODCALLTYPE GetInputFeatures(
      ABI::Windows::Foundation::Collections::IVectorView<winml::ILearningModelFeatureDescriptor>** out_features) = 0;

  // Hands back a vector view of the descriptors for the model's outputs.
  virtual HRESULT STDMETHODCALLTYPE GetOutputFeatures(
      ABI::Windows::Foundation::Collections::IVectorView<winml::ILearningModelFeatureDescriptor>** out_features) = 0;
};
|
|
|
|
// COM wrapper around an onnx::ModelProto.
//
// NOTE: the declaration order below fixes this interface's vtable layout;
// never reorder, insert, or remove members.
MIDL_INTERFACE("a848faf6-5a2e-4a7f-b622-cc036f71e28a")
IModelProto : IUnknown {
  // Non-owning access to the wrapped proto: the returned pointer is a weak
  // reference, so the caller must not free it.
  virtual onnx::ModelProto* STDMETHODCALLTYPE get() = 0;

  // Transfers ownership of the wrapped proto to the caller. Afterwards this
  // wrapper forgets about it. Per the original contract this happens "without
  // touching the reference" (NOTE(review): presumably the COM reference count
  // — confirm).
  virtual onnx::ModelProto* STDMETHODCALLTYPE detach() = 0;
};
|
|
|
|
// COM-facing handle to an onnxruntime::InferenceSession, plus the
// session-level operations WinML drives through it. These cover model loading,
// custom operator registries, profiling, and execution-provider hooks.
//
// NOTE: the declaration order below fixes this interface's vtable layout;
// never reorder, insert, or remove members.
MIDL_INTERFACE("6ec766ef-6365-42bf-b64f-ae85c015adb8")
IInferenceSession : IUnknown {
  // Raw access to the underlying onnxruntime session.
  virtual onnxruntime::InferenceSession* STDMETHODCALLTYPE get() = 0;

  // Exposes the session as a C API OrtSession handle.
  // The handle is a weak reference: DO NOT release it.
  virtual HRESULT STDMETHODCALLTYPE GetOrtSession(OrtSession** out_session) = 0;

  virtual void STDMETHODCALLTYPE RegisterGraphTransformers() = 0;
  virtual HRESULT STDMETHODCALLTYPE RegisterCustomRegistry(IMLOperatorRegistry* operator_registry) = 0;
  virtual HRESULT STDMETHODCALLTYPE LoadModel(IModelProto* model) = 0;

  // --- profiling -----------------------------------------------------------
  virtual HRESULT STDMETHODCALLTYPE StartProfiling() = 0;
  virtual HRESULT STDMETHODCALLTYPE EndProfiling() = 0;

  // --- execution-provider hooks --------------------------------------------
  // NOTE(review): the parameter name suggests these expect the DirectML
  // provider specifically — confirm before passing another provider.
  virtual void STDMETHODCALLTYPE FlushContext(onnxruntime::IExecutionProvider* dml_provider) = 0;
  virtual void STDMETHODCALLTYPE TrimUploadHeap(onnxruntime::IExecutionProvider* dml_provider) = 0;
  virtual void STDMETHODCALLTYPE ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) = 0;

  // Produces `new_mlvalue` from `orig_mlvalue` for the input called `input_name`.
  // NOTE(review): presumably copies the value to the device the session's
  // provider expects — confirm against the implementation.
  virtual HRESULT STDMETHODCALLTYPE CopyOneInputAcrossDevices(
      const char* input_name,
      const OrtValue* orig_mlvalue,
      OrtValue** new_mlvalue) = 0;
};
|
|
|
|
// The IOrtSessionBuilder offers an abstraction over the creation of
|
|
// InferenceSession, that enables the creation of the session based on a device (CPU/DML).
|
|
MIDL_INTERFACE("2746f03a-7e08-4564-b5d0-c670fef116ee") IOrtSessionBuilder : IUnknown {
|
|
|
|
virtual HRESULT STDMETHODCALLTYPE CreateSessionOptions(
|
|
OrtSessionOptions ** options) = 0;
|
|
|
|
virtual HRESULT STDMETHODCALLTYPE CreateSession(
|
|
OrtSessionOptions * options,
|
|
IInferenceSession** session,
|
|
onnxruntime::IExecutionProvider** provider) = 0;
|
|
|
|
virtual HRESULT STDMETHODCALLTYPE Initialize(
|
|
IInferenceSession* session,
|
|
onnxruntime::IExecutionProvider* provider) = 0;
|
|
};
|
|
|
|
|
|
// Entry point through which WinML reaches onnxruntime internals. It covers
// model-proto and session-builder factories, custom operator registries,
// DirectML execution-provider hooks, and memory/type queries on OrtValues.
//
// NOTE: the declaration order below fixes this interface's vtable layout;
// never reorder, insert, or remove members.
MIDL_INTERFACE("b19385e7-d9af-441a-ba7f-3993c7b1c9db")
IWinMLAdapter : IUnknown {
  virtual void STDMETHODCALLTYPE EnableDebugOutput() = 0;

  virtual HRESULT STDMETHODCALLTYPE EnsureModelDeviceCompatibility(
      winml::LearningModel const& model,
      IModelProto* p_model_proto,
      bool is_float16_supported) = 0;

  // --- factories ------------------------------------------------------------
  // Creates an IOrtSessionBuilder for the given D3D12 device and command queue.
  virtual HRESULT STDMETHODCALLTYPE CreateOrtSessionBuilder(
      ID3D12Device* device,
      ID3D12CommandQueue* queue,
      IOrtSessionBuilder** out_session_builder) = 0;

  // Model-proto factories, from a file path, a stream reference, or another
  // IModelProto.
  // NOTE(review): these are overloaded virtuals in a COM-style interface.
  // MSVC is known to group overloaded virtuals together in the vtable rather
  // than using plain declaration order. That is fine only while every caller
  // and implementer is built with the same compiler.
  virtual HRESULT STDMETHODCALLTYPE CreateModelProto(const char* path, IModelProto** out_model_proto) = 0;
  virtual HRESULT STDMETHODCALLTYPE CreateModelProto(
      ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream_reference,
      IModelProto** out_model_proto) = 0;
  virtual HRESULT STDMETHODCALLTYPE CreateModelProto(IModelProto* source_model_proto, IModelProto** out_model_proto) = 0;
  virtual HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto* model_proto, IModelInfo** out_model_info) = 0;

  // --- data types (no members yet) ------------------------------------------

  // --- custom operators -----------------------------------------------------
  virtual HRESULT STDMETHODCALLTYPE GetCustomRegistry(IMLOperatorRegistry** out_registry) = 0;
  virtual HRESULT STDMETHODCALLTYPE GetOperatorRegistry(
      ILearningModelOperatorProviderNative* operator_provider_native,
      IMLOperatorRegistry** out_registry) = 0;

  // --- DirectML execution-provider hooks ------------------------------------
  // NOTE(review): CreateGPUAllocationFromD3DResource / FreeGPUAllocation look
  // like an allocate/free pair — confirm before relying on that.
  virtual void* STDMETHODCALLTYPE CreateGPUAllocationFromD3DResource(ID3D12Resource* resource) = 0;
  virtual void STDMETHODCALLTYPE FreeGPUAllocation(void* allocation) = 0;
  virtual HRESULT STDMETHODCALLTYPE CopyTensor(onnxruntime::IExecutionProvider* provider, OrtValue* src, OrtValue* dst) = 0;
  // Returns a weak reference: do not Release the returned resource.
  virtual ID3D12Resource* STDMETHODCALLTYPE GetD3D12ResourceFromAllocation(
      onnxruntime::IExecutionProvider* provider,
      void* allocation) = 0;

  // --- schema overrides (the DML provider does this for us) -----------------
  virtual HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() = 0;

  // --- proposed adapter surface, using the cross-platform ABI types ---------
  virtual HRESULT STDMETHODCALLTYPE GetProviderMemoryInfo(
      onnxruntime::IExecutionProvider* provider,
      OrtMemoryInfo** out_memory_info) = 0;
  virtual HRESULT STDMETHODCALLTYPE GetProviderAllocator(
      onnxruntime::IExecutionProvider* provider,
      OrtAllocator** out_allocator) = 0;
  virtual HRESULT STDMETHODCALLTYPE GetValueMemoryInfo(const OrtValue* value, OrtMemoryInfo** out_memory_info) = 0;
  virtual HRESULT STDMETHODCALLTYPE GetMapType(
      const OrtValue* ort_value,
      ONNXTensorElementDataType* key_type,
      ONNXTensorElementDataType* value_type) = 0;
  virtual HRESULT STDMETHODCALLTYPE GetVectorMapType(
      const OrtValue* ort_value,
      ONNXTensorElementDataType* key_type,
      ONNXTensorElementDataType* value_type) = 0;

  // Not part of the interface yet (kept for reference):
  //virtual HRESULT STDMETHODCALLTYPE CreateTensorFromMap(IInspectable * map, OrtValue * *ort_value) = 0;
  //virtual HRESULT STDMETHODCALLTYPE CreateTensorFromSequence(IInspectable * sequence, OrtValue * *ort_value) = 0;
};
|
|
|
|
extern "C"
|
|
__declspec(dllexport) HRESULT STDMETHODCALLTYPE OrtGetWinMLAdapter(IWinMLAdapter** adapter);
|
|
|
|
// WRL classic-COM implementation of IInferenceSession. It is backed by an
// onnxruntime::InferenceSession held in a std::shared_ptr (`session_`).
class InferenceSession : public Microsoft::WRL::RuntimeClass<
                             Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
                             IInferenceSession> {
 public:
  // Wraps `session`. The constructor is defined out of line.
  // NOTE(review): because `session_` is a std::shared_ptr, this presumably
  // takes ownership of `session` — confirm in the implementation file.
  // `explicit` so a raw onnxruntime::InferenceSession* can never be implicitly
  // converted into a COM wrapper that might then take ownership of it.
  // WRL's Make<> uses direct-initialization, so it is unaffected.
  explicit InferenceSession(onnxruntime::InferenceSession* session);

  // Returns the wrapped session. The pointer is non-owning; ownership stays
  // with `session_`.
  onnxruntime::InferenceSession* STDMETHODCALLTYPE get() noexcept override {
    return session_.get();
  }

  // Returns the wrapped session as a C API OrtSession handle through `out`.
  // The handle is a weak reference: callers must NOT release it.
  // Returns E_POINTER when `out` is null; otherwise S_OK.
  HRESULT STDMETHODCALLTYPE GetOrtSession(OrtSession** out) noexcept override {
    if (out == nullptr) {
      // COM convention for a missing out-parameter; prevents a null dereference.
      return E_POINTER;
    }
    // An OrtSession* handle is really an onnxruntime::InferenceSession*.
    *out = reinterpret_cast<OrtSession*>(session_.get());
    return S_OK;
  }

  // The remaining IInferenceSession members are implemented out of line.
  void STDMETHODCALLTYPE RegisterGraphTransformers() override;
  HRESULT STDMETHODCALLTYPE RegisterCustomRegistry(IMLOperatorRegistry* registry) override;
  HRESULT STDMETHODCALLTYPE LoadModel(IModelProto* model_proto) override;
  HRESULT STDMETHODCALLTYPE StartProfiling() override;
  HRESULT STDMETHODCALLTYPE EndProfiling() override;
  void STDMETHODCALLTYPE FlushContext(onnxruntime::IExecutionProvider* dml_provider) override;
  void STDMETHODCALLTYPE TrimUploadHeap(onnxruntime::IExecutionProvider* dml_provider) override;
  void STDMETHODCALLTYPE ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) override;
  HRESULT STDMETHODCALLTYPE CopyOneInputAcrossDevices(const char* input_name,
                                                      const OrtValue* orig_mlvalue,
                                                      OrtValue** new_mlvalue) override;

 private:
  // The wrapped session; get() and GetOrtSession() hand out non-owning views of it.
  std::shared_ptr<onnxruntime::InferenceSession> session_;
};
|
|
|
|
} // namespace Windows::AI::MachineLearning::Adapter
|