onnxruntime/winml/adapter/winml_adapter_c_api.h
Justin Chu 2575b9aaa1
Improve comments in winml/ (#17163)
Follow up of #17144. Manually fixed indentation in block comments and
replaced all tabs with spaces.
2023-08-15 23:30:56 -04:00

578 lines
23 KiB
C

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/session/onnxruntime_c_api.h"
#include "winrt/windows.foundation.collections.h"
/**
* All APIs exported by winml_adapter_c_api.h are part of the private interface dedicated to supporting the WinML API.
* This contract is subject to change based on the needs of the WinML API and is not intended for direct use by callers
* of the onnxruntime C API; usage of the APIs in this header is *not* supported by the onnxruntime product.
*/
// Opaque handle types used by the adapter. ORT_RUNTIME_CLASS comes from onnxruntime_c_api.h and
// provides the OrtModel, OrtExecutionProvider and OrtThreadPool handle types referenced below.
ORT_RUNTIME_CLASS(Model);
ORT_RUNTIME_CLASS(ExecutionProvider);
ORT_RUNTIME_CLASS(ThreadPool);

// Direct3D 12 / WinML operator-registry interfaces. They are only used through pointers in this
// header, so forward declarations are sufficient (presumably to avoid including the D3D12 headers here).
struct ID3D12Device;
struct ID3D12CommandQueue;
struct ID3D12Resource;
struct IMLOperatorRegistry;

// The adapter function table; its full definition appears later in this header.
struct WinmlAdapterApi;
typedef struct WinmlAdapterApi WinmlAdapterApi;
// Category of a profiler event, as carried in OrtProfilerEventRecord::category_.
// TODO: Must match onnxruntime::profiling::EventRecord
enum OrtProfilerEventCategory {
  SESSION_EVENT = 0,      // presumably an event attributed to the session as a whole
  NODE_EVENT = 1,         // presumably an event attributed to a single graph node
  EVENT_CATEGORY_MAX = 2  // by naming convention a count/sentinel, not a real category
};
// One profiler event, delivered to an OrtProfilingFunction callback.
// NOTE(review): the field set presumably mirrors onnxruntime::profiling::EventRecord; confirm before
// reordering or retyping fields, since the layout is shared across the adapter boundary.
// NOTE(review): this header does not show who owns the string pointers; presumably they are only
// valid for the duration of the callback — confirm before retaining them.
struct OrtProfilerEventRecord {
// Category of the event (session-level or node-level).
OrtProfilerEventCategory category_;
// Name of the event category as a string.
const char* category_name_;
// Duration of the event. NOTE(review): units are not shown here (presumably microseconds) — confirm.
int64_t duration_;
// Timestamp of the event. NOTE(review): presumably the start time, in the same units as duration_ — confirm.
int64_t time_span_;
// Name of the event.
const char* event_name_;
// Process and thread that emitted the event.
int32_t process_id_;
int32_t thread_id_;
// Operator name. NOTE(review): presumably only meaningful for NODE_EVENT — confirm.
const char* op_name_;
// Execution provider that ran the operator. NOTE(review): presumably only meaningful for NODE_EVENT — confirm.
const char* execution_provider_;
};
// Callback that receives profiler events; it is registered with the environment through
// WinmlAdapterApi::EnvConfigureCustomLoggerAndProfiler.
typedef void(ORT_API_CALL* OrtProfilingFunction)(const OrtProfilerEventRecord* event_record);
// Kind of thread pool requested from WinmlAdapterApi::CreateThreadPool. A session takes one pool of
// each kind (see CreateSessionWithoutModel).
enum class ThreadPoolType : uint8_t {
  INTRA_OP = 0,  // pool used to parallelize work inside a single operator
  INTER_OP = 1   // pool used to run independent operators concurrently
};
// Configuration for a thread pool created through WinmlAdapterApi::CreateThreadPool.
// The struct is passed by pointer across the adapter boundary, so keep the field order stable.
struct OrtThreadPoolOptions {
// 0: Use the default setting (all the physical cores, or half of the logical cores).
// 1: Don't create a thread pool.
// n: Create a thread pool with n threads.
int thread_pool_size = 0;
// If true and thread_pool_size == 0, populate the thread affinity information in ThreadOptions.
// Otherwise, if thread_options already has affinity information, use it and apply it.
// In any other case, don't set affinity.
bool auto_set_affinity = false;
// If true, pool threads keep spinning for a while after the work queue becomes empty.
bool allow_spinning = true;
// If it is non-negative, the thread pool splits a task using a decreasing block size of
// remaining_of_total_iterations / (num_of_threads * dynamic_block_base_).
// NOTE(review): a value of 0 (the default) would make that divisor zero, so presumably only positive
// values enable dynamic blocking and 0 disables it — confirm against the implementation.
int dynamic_block_base_ = 0;
// Stack size for pool threads. NOTE(review): presumably bytes, with 0 meaning the platform default — confirm.
unsigned int stack_size = 0;
// Optional pool name (nullptr when unset). NOTE(review): presumably used to name the threads — confirm.
const ORTCHAR_T* name = nullptr;
// Whether pool threads treat denormal floating-point values as zero.
bool set_denormal_as_zero = false;
};
struct WinmlAdapterApi {
/**
* OverrideSchema
* This api is used to override schema inference functions for a variety of ops across opsets.
* This exists because certain ops were failing to infer schemas and caused performance
* issues for DML as it was forced to create resources during evaluation.
* This can be removed when schema inference functions have been updated.
*/
OrtStatus*(ORT_API_CALL* OverrideSchema)() NO_EXCEPTION;
/**
* EnvConfigureCustomLoggerAndProfiler
* This api is used to add a custom logger and profiler to the ors environment.
* This exists because existing methods on the c-abi to create the environment only support a custom logger.
* Since WinML hooks the profiler events, we expose the profiler and an associated profiling function.
*/
OrtStatus*(ORT_API_CALL* EnvConfigureCustomLoggerAndProfiler)(
_In_ OrtEnv* env,
OrtLoggingFunction logging_function,
OrtProfilingFunction profiling_function,
_In_opt_ void* logger_param,
OrtLoggingLevel default_warning_level,
_In_ const char* logid,
_Outptr_ OrtEnv** out
)NO_EXCEPTION;
// OrtModel methods
/**
* CreateModelFromPath
* This api creates an OrtModel based on a specified model path.
* There is no inferencing or evaluation setup performed. Only ONNX load is done to reflect on the model's inputs/outputs and other properties.
* This is used by WinML to support model reflection APIs.
*/
OrtStatus*(ORT_API_CALL* CreateModelFromPath)(_In_ const char* model_path, _In_ size_t size, _Outptr_ OrtModel** out)
NO_EXCEPTION;
/**
* CreateModelFromData
* This api creates an OrtModel from a buffer.
* There is no inferencing or evaluation setup performed. Only ONNX load is done to reflect on the model's inputs/outputs and other properties.
* This is used by WinML to support model reflection APIs.
*/
OrtStatus*(ORT_API_CALL* CreateModelFromData)(_In_opt_ void* data, _In_ size_t size, _Outptr_ OrtModel** out)
NO_EXCEPTION;
/**
* CloneModel
* This api copies the OrtModel along with its internal proto buffer and cached metadata.
* The OrtSession type expects to own the model proto buffer.
* WinML uses this to yield copies of the model proto held by OrtModel to OrtSession.
*/
OrtStatus*(ORT_API_CALL* CloneModel)(_In_ const OrtModel* in, _Outptr_ OrtModel** out)NO_EXCEPTION;
/**
* ModelGetAuthor
* This api gets the model author from the OrtModel.
* This is used by WinML to support model reflection APIs.
*/
OrtStatus*(ORT_API_CALL* ModelGetAuthor)(
_In_ const OrtModel* model, _Out_ const char** const author, _Out_ size_t* len
)NO_EXCEPTION;
/**
* ModelGetName
* This api gets the model name from the OrtModel.
* This is used by WinML to support model reflection APIs.
*/
OrtStatus*(ORT_API_CALL* ModelGetName)(_In_ const OrtModel* model, _Out_ const char** const name, _Out_ size_t* len)
NO_EXCEPTION;
/**
* ModelSetName
* This api set the model name from the OrtModel.
* This is used by the Windows ML Samples Gallery to change the model name for telemetry.
*/
OrtStatus*(ORT_API_CALL* ModelSetName)(_In_ const OrtModel* model, _In_ const char* name)NO_EXCEPTION;
/**
* ModelGetDomain
* This api gets the model domain from the OrtModel.
* This is used by WinML to support model reflection APIs.
*/
OrtStatus*(ORT_API_CALL* ModelGetDomain)(
_In_ const OrtModel* model, _Out_ const char** const domain, _Out_ size_t* len
)NO_EXCEPTION;
/**
* ModelGetDescription
* This api gets the model description from the OrtModel.
* This is used by WinML to support model reflection APIs.
*/
OrtStatus*(ORT_API_CALL* ModelGetDescription)(
_In_ const OrtModel* model, _Out_ const char** const description, _Out_ size_t* len
)NO_EXCEPTION;
/**
* ModelGetVersion
* This api gets the model version from the OrtModel.
* This is used by WinML to support model reflection APIs.
*/
OrtStatus*(ORT_API_CALL* ModelGetVersion)(_In_ const OrtModel* model, _Out_ int64_t* version)NO_EXCEPTION;
/**
* ModelGetInputCount
* This api gets the number of inputs from the OrtModel. It closely matches the API of a similar name similar name for retrieving model metadata from OrtSession.
* This is used by WinML to support model reflection APIs.
*/
OrtStatus*(ORT_API_CALL* ModelGetInputCount)(_In_ const OrtModel* model, _Out_ size_t* count)NO_EXCEPTION;
/**
* ModelGetOutputCount
* This api gets the number of outputs from the OrtModel. It closely matches the API of a similar name for retrieving model metadata from OrtSession.
* This is used by WinML to support model reflection APIs.
*/
OrtStatus*(ORT_API_CALL* ModelGetOutputCount)(_In_ const OrtModel* model, _Out_ size_t* count)NO_EXCEPTION;
/**
* ModelGetInputName
* This api gets the input name from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession.
* This is used by WinML to support model reflection APIs.
*/
OrtStatus*(ORT_API_CALL* ModelGetInputName)(
_In_ const OrtModel* model, _In_ size_t index, _Out_ const char** input_name, _Out_ size_t* count
)NO_EXCEPTION;
/**
* ModelGetOutputName
* This api gets the output name from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession.
* This is used by WinML to support model reflection APIs.
*/
OrtStatus*(ORT_API_CALL* ModelGetOutputName)(
_In_ const OrtModel* model, _In_ size_t index, _Out_ const char** output_name, _Out_ size_t* count
)NO_EXCEPTION;
/**
* ModelGetInputDescription
* This api gets the input description from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession.
* This is used by WinML to support model reflection APIs.
*/
OrtStatus*(ORT_API_CALL* ModelGetInputDescription)(
_In_ const OrtModel* model, _In_ size_t index, _Out_ const char** input_description, _Out_ size_t* count
)NO_EXCEPTION;
/**
* ModelGetOutputDescription
* This api gets the output description from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession.
* This is used by WinML to support model reflection APIs.
*/
OrtStatus*(ORT_API_CALL* ModelGetOutputDescription)(
_In_ const OrtModel* model, _In_ size_t index, _Out_ const char** output_description, _Out_ size_t* count
)NO_EXCEPTION;
/**
* ModelGetInputTypeInfo
* This api gets the input OrtTypeInfo from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession.
* This is used by WinML to support model reflection APIs.
*/
OrtStatus*(ORT_API_CALL* ModelGetInputTypeInfo)(
_In_ const OrtModel* model, _In_ size_t index, _Outptr_ OrtTypeInfo** type_info
)NO_EXCEPTION;
/**
* ModelGetOutputTypeInfo
* This api gets the output OrtTypeInfo from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession.
* This is used by WinML to support model reflection APIs.
*/
OrtStatus*(ORT_API_CALL* ModelGetOutputTypeInfo)(
_In_ const OrtModel* model, _In_ size_t index, _Outptr_ OrtTypeInfo** type_info
)NO_EXCEPTION;
/**
* ModelGetMetadataCount
* This api gets the number of metadata entries from the OrtModel.
* This is used by WinML to support model reflection APIs.
*/
OrtStatus*(ORT_API_CALL* ModelGetMetadataCount)(_In_ const OrtModel* model, _Out_ size_t* count)NO_EXCEPTION;
/**
* ModelGetMetadata
* This api gets the model metadata from the OrtModel.
* This is used by WinML to deduce whether model input and output formats are supported by the WinML tensorization code paths.
*/
OrtStatus*(ORT_API_CALL* ModelGetMetadata)(
_In_ const OrtModel* model,
_Out_ size_t count,
_Out_ const char** const key,
_Out_ size_t* key_len,
_Out_ const char** const value,
_Out_ size_t* value_len
)NO_EXCEPTION;
/**
* ModelEnsureNoFloat16
* This api checks whether the model requires float 16 support.
* This is used by WinML to fail gracefully when float 16 support is not available on the device.
*
* Can this API be moved into the EP during session initialization. Currently we do an early fp16 check to avoid initialization when it is not supported.
*/
OrtStatus*(ORT_API_CALL* ModelEnsureNoFloat16)(_In_ const OrtModel* model)NO_EXCEPTION;
/**
* SaveModel
* This api save the model to the fiven file
*/
OrtStatus*(ORT_API_CALL* SaveModel)(_In_ const OrtModel* in, _In_ const wchar_t* const file_name, _In_ size_t len)
NO_EXCEPTION;
// OrtSessionOptions methods
/**
* OrtSessionOptionsAppendExecutionProvider_CPU
* This api is used to add the cpu EP to OrtSessionOptions so that WinML Gpu session are configures with CPU fallback.
*/
OrtStatus*(ORT_API_CALL* OrtSessionOptionsAppendExecutionProvider_CPU)(_In_ OrtSessionOptions* options, int use_arena)
NO_EXCEPTION;
/**
* OrtSessionOptionsAppendExecutionProvider_DML
* This api is used to add the DML EP to OrtSessionOptions.
*/
OrtStatus*(ORT_API_CALL* OrtSessionOptionsAppendExecutionProvider_DML)(
_In_ OrtSessionOptions* options, ID3D12Device* device, ID3D12CommandQueue* queue, bool metacommands_enabled
)NO_EXCEPTION;
// OrtSession methods
/**
* CreateSessionWithoutModel
* This api is used to create a Session that is completely uninitialized. While there are other Session creation APIs in the
* c-abi, WinML uses this so that it can perform optimizations prior to loading the model, and initializing.
* Moreover, WinML needs a new api to support the OrtModel type, and prevent the parsing model protobufs again on session creation.
*/
OrtStatus*(ORT_API_CALL* CreateSessionWithoutModel)(
_In_ OrtEnv* env,
_In_ const OrtSessionOptions* options,
_In_ OrtThreadPool* inter_op_thread_pool,
_In_ OrtThreadPool* intra_op_thread_pool,
_Outptr_ OrtSession** session
)NO_EXCEPTION;
/**
* SessionGetExecutionProvider
* This api is used to get a handle to an OrtExecutionProvider.
* Currently WinML uses this to talk directly to the DML EP and configure settings on it.
*/
OrtStatus*(ORT_API_CALL* SessionGetExecutionProvider)(
_In_ OrtSession* session, _In_ size_t index, _Out_ OrtExecutionProvider** provider
)NO_EXCEPTION;
/**
* SessionInitialize
* This api is used to initialize an OrtSession. This is one component of creating a usable OrtSession, and is a part of CreateSession in the c-abi.
* Currently WinML uses this to finalize session creation, after configuring a variety of properties on the OrtSession.
*/
OrtStatus*(ORT_API_CALL* SessionInitialize)(_In_ OrtSession* session)NO_EXCEPTION;
/**
* SessionRegisterGraphTransformers
* This api is used to enable DML specific graph transformations on an OrtSession.
*
* Ideally these transformations should be configured by the contract between the runtime and the EP and not overridden by WinML.
*/
OrtStatus*(ORT_API_CALL* SessionRegisterGraphTransformers)(_In_ OrtSession* session)NO_EXCEPTION;
/**
* SessionRegisterCustomRegistry
* This api is used to support custom operators as they were shipped in WinML RS5.
*/
OrtStatus*(ORT_API_CALL* SessionRegisterCustomRegistry)(_In_ OrtSession* session, _In_ IMLOperatorRegistry* registry)
NO_EXCEPTION;
/**
* SessionLoadAndPurloinModel
* This api is used to load an OrtModel into an OrtSession.
*
* Don't free the 'out' value as this API will defunct and release the OrtModel internally.
*/
OrtStatus*(ORT_API_CALL* SessionLoadAndPurloinModel)(_In_ OrtSession* session, _In_ OrtModel* model)NO_EXCEPTION;
/**
* SessionStartProfiling
* This api is used to start profiling OrtSession. The existing mechanism only allows configuring profiling at session creation.
*
* WinML uses this to toggle profilling on and off based on if a telemetry providers are being listened to.
*/
OrtStatus*(ORT_API_CALL* SessionStartProfiling)(_In_ OrtEnv* env, _In_ OrtSession* session)NO_EXCEPTION;
/**
* SessionEndProfiling
* This api is used to end profiling OrtSession. The existing mechanism only allows configuring profiling at session creation.
*
* WinML uses this to toggle profilling on and off based on if a telemetry providers are being listened to.
*/
OrtStatus*(ORT_API_CALL* SessionEndProfiling)(_In_ OrtSession* session)NO_EXCEPTION;
/**
* SessionCopyOneInputAcrossDevices
* This api is used to copy and create an OrtValue input to prepare the input on the correct device.
*
* WinML uses this to copy gpu device OrtValues to the CPU and vice-versa.
*/
OrtStatus*(ORT_API_CALL* SessionCopyOneInputAcrossDevices)(
_In_ OrtSession* session,
_In_ const char* const input_name,
_In_ OrtValue* orig_value,
_Outptr_ OrtValue** new_value
)NO_EXCEPTION;
// Dml methods (TODO need to figure out how these need to move to session somehow...)
/**
* SessionGetNumberOfIntraOpThreads
* This api returns the number of intra operator threads set on the OrtSession.
*
* WinML uses this to determine that the correct number of threads was set correctly through OrtSessionOptions.
*/
OrtStatus*(ORT_API_CALL* SessionGetNumberOfIntraOpThreads)(_In_ OrtSession* session, _Out_ uint32_t* num_threads)
NO_EXCEPTION;
/**
* SessionGetIntrapOpThreadSpinning
* This api returns false if the ort session options config entry "session.intra_op.allow_spinning" is set to "0", and true otherwise
*
* WinML uses this to determine that the intra op thread spin policy was set correctly through OrtSessionOptions
*/
OrtStatus*(ORT_API_CALL* SessionGetIntraOpThreadSpinning)(_In_ OrtSession* session, _Out_ bool* allow_spinning)
NO_EXCEPTION;
/**
* SessionGetNamedDimensionsOverrides
* This api returns the named dimension overrides that are specified for this session
*
* WinML uses this to determine that named dimension overrides were set correctly through OrtSessionOptions.
*/
OrtStatus*(ORT_API_CALL* SessionGetNamedDimensionsOverrides)(
_In_ OrtSession* session,
_Out_ winrt::Windows::Foundation::Collections::IMapView<winrt::hstring, uint32_t>& overrides
)NO_EXCEPTION;
/**
* DmlExecutionProviderFlushContext
* This api is used to flush the DML EP.
*
* WinML communicates directly with DML to perform this as an optimization.
*/
OrtStatus*(ORT_API_CALL* DmlExecutionProviderFlushContext)(_In_ OrtExecutionProvider* dml_provider)NO_EXCEPTION;
/**
* DmlExecutionProviderReleaseCompletedReferences
* This api is used to release completed references after first run the DML EP.
*
* WinML communicates directly with DML to perform this as an optimization.
*/
OrtStatus*(ORT_API_CALL* DmlExecutionProviderReleaseCompletedReferences)(_In_ OrtExecutionProvider* dml_provider
)NO_EXCEPTION;
/**
* DmlCopyTensor
* This api is used copy a tensor allocated by the DML EP Allocator to the CPU.
*
* WinML uses this when graphs are evaluated with DML, and their outputs remain on the GPU but need to be copied back to the CPU.
*/
OrtStatus*(ORT_API_CALL* DmlCopyTensor)(_In_ OrtExecutionProvider* provider, _In_ OrtValue* src, _In_ OrtValue* dst)
NO_EXCEPTION;
/**
* GetProviderMemoryInfo
* This api gets the memory info object associated with an EP.
*
* WinML uses this to manage caller specified D3D12 inputs/outputs. It uses the memory info here to call DmlCreateGPUAllocationFromD3DResource.
*/
OrtStatus*(ORT_API_CALL* GetProviderMemoryInfo)(_In_ OrtExecutionProvider* provider, OrtMemoryInfo** memory_info)
NO_EXCEPTION;
/**
* GetProviderAllocator
* This api gets associated allocator used by a provider.
*
* WinML uses this to create tensors, and needs to hold onto the allocator for the duration of the associated value's lifetime.
*/
OrtStatus*(ORT_API_CALL* GetProviderAllocator)(
_In_ OrtSession* session, _In_ OrtExecutionProvider* provider, OrtAllocator** allocator
)NO_EXCEPTION;
/**
* FreeProviderAllocator
* This api frees an allocator.
*
* WinML uses this to free the associated allocator for an ortvalue when creating tensors.
* Internally this derefs a shared_ptr.
*/
OrtStatus*(ORT_API_CALL* FreeProviderAllocator)(_In_ OrtAllocator* allocator)NO_EXCEPTION;
/**
* ExecutionProviderSync
* This api syncs the EP.
*
* WinML uses this to sync EP inputs/outputs directly.
*/
OrtStatus*(ORT_API_CALL* ExecutionProviderSync)(_In_ OrtExecutionProvider* provider)NO_EXCEPTION;
/**
* CreateCustomRegistry
* This api creates a custom registry that callers can populate with custom ops.
*
* WinML uses this to support custom ops.
*/
OrtStatus*(ORT_API_CALL* CreateCustomRegistry)(_Out_ IMLOperatorRegistry** registry)NO_EXCEPTION;
/**
* ValueGetDeviceId
* This api returns the device id of the OrtValue.
*
* WinML uses this to determine if an OrtValue is created on the needed device.
*/
OrtStatus*(ORT_API_CALL* ValueGetDeviceId)(_In_ OrtValue* ort_value, _Out_ int16_t* device_id)NO_EXCEPTION;
/**
* SessionGetInputRequiredDeviceId
* This api returns the required device id for a model input.
*
* WinML uses this to determine if an OrtValue is created on the needed device.
*/
OrtStatus*(ORT_API_CALL* SessionGetInputRequiredDeviceId)(
_In_ OrtSession* session, _In_ const char* const input_name, _Out_ int16_t* device_id
)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* CreateTensorTypeInfo)(
_In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Out_ OrtTypeInfo** type_info
)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* CreateSequenceTypeInfo)(_Out_ OrtTypeInfo** type_info)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* CreateMapTypeInfo)(_Out_ OrtTypeInfo** type_info)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* CreateModel)(_In_ int64_t opset, _Outptr_ OrtModel** out)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* ModelAddInput)(
_In_ OrtModel* model, _In_ const char* const input_name, _In_ OrtTypeInfo* info
)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* ModelAddConstantInput)(
_In_ OrtModel* model, _In_ const char* const input_name, _In_ OrtTypeInfo* info, _In_ OrtValue* value
)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* ModelAddOutput)(
_In_ OrtModel* model, _In_ const char* const output_name, _In_ OrtTypeInfo* info
)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* ModelAddOperator)(
_In_ OrtModel* model,
_In_ const char* const op_type,
_In_ const char* const op_name,
_In_ int64_t opset,
_In_ const char* const op_domain,
_In_ const char* const* input_names,
_In_ size_t num_inputs,
_In_ const char* const* output_names,
_In_ size_t num_outputs,
_In_ const char* const* attribute_names,
_In_ OrtValue** attribute_values,
_In_ size_t num_attributes
)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* ModelGetOpsetVersion)(
_In_ OrtModel* model, _In_ const char* const domain, _Out_ int32_t* version
)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* OperatorGetNumInputs)(
_In_ const char* const op_type, _In_ int64_t opset, _In_ const char* const op_domain, _Out_ size_t* num_inputs
)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* OperatorGetInputName)(
_In_ const char* const op_type,
_In_ int64_t opset,
_In_ const char* const op_domain,
_In_ size_t index,
_Out_ const char** const name
)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* OperatorGetNumOutputs)(
_In_ const char* const op_type, _In_ int64_t opset, _In_ const char* const op_domain, _Out_ size_t* num_inputs
)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* OperatorGetOutputName)(
_In_ const char* const op_type,
_In_ int64_t opset,
_In_ const char* const op_domain,
_In_ size_t index,
_Out_ const char** const name
)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* JoinModels)(
_In_ OrtModel* first_model,
_In_ OrtModel* second_model,
_In_ const char* const* output_names,
_In_ const char* const* input_names,
size_t num_linkages,
bool promote_unlinked_outputs,
_In_ const char* const join_node_prefix
)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* CreateThreadPool)(
_In_ ThreadPoolType type, _In_ OrtThreadPoolOptions* params, _Outptr_ OrtThreadPool** out
)NO_EXCEPTION;
ORT_CLASS_RELEASE(Model);
ORT_CLASS_RELEASE(ThreadPool);
};