diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 2f235f5f38..c5f92c9ef2 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -37,6 +37,7 @@ namespace onnxruntime { constexpr const char* CPU = "Cpu"; constexpr const char* CUDA = "Cuda"; constexpr const char* CUDA_PINNED = "CudaPinned"; +constexpr const char* DML = "DML"; constexpr const char* MIGRAPHX = "MIGraphX"; constexpr const char* MIGRAPHX_PINNED = "MIGraphXPinned"; diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h index f55c4c09fa..308b7faa5f 100644 --- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h +++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h @@ -29,7 +29,11 @@ extern "C" { * the adapter index. The device ID corresponds to the enumeration order of hardware adapters as given by * IDXGIFactory::EnumAdapters. A device_id of 0 always corresponds to the default adapter, which is typically the * primary display GPU installed on the system. A negative device_id is invalid. -*/ + * + * [[deprecated]] + * This export should be deprecated. + * The OrtSessionOptionsAppendExecutionProvider_DML export on the OrtDmlApi should be used instead. + */ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id); /** @@ -40,10 +44,58 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOpti * objects. * See also: DMLCreateDevice * See also: ID3D12Device::CreateCommandQueue + * + * [[deprecated]] + * This export should be deprecated. + * The OrtSessionOptionsAppendExecutionProvider_DML1 export on the OrtDmlApi should be used instead. */ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options, _In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue); + +struct OrtDmlApi; +typedef struct OrtDmlApi OrtDmlApi; + +struct OrtDmlApi { + /** + * Creates a DirectML Execution Provider which executes on the hardware adapter with the given device_id, also known as + * the adapter index. The device ID corresponds to the enumeration order of hardware adapters as given by + * IDXGIFactory::EnumAdapters. A device_id of 0 always corresponds to the default adapter, which is typically the + * primary display GPU installed on the system. A negative device_id is invalid. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id); + + /** + * Creates a DirectML Execution Provider using the given DirectML device, and which executes work on the supplied D3D12 + * command queue. The DirectML device and D3D12 command queue must have the same parent ID3D12Device, or an error will + * be returned. The D3D12 command queue must be of type DIRECT or COMPUTE (see D3D12_COMMAND_LIST_TYPE). If this + * function succeeds, the inference session maintains a strong reference on both the dml_device and the command_queue + * objects. + * See also: DMLCreateDevice + * See also: ID3D12Device::CreateCommandQueue + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML1, _In_ OrtSessionOptions* options, + _In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue); + + /** + * CreateGPUAllocationFromD3DResource + * This api is used to create a DML EP input based on a user specified d3d12 resource. + */ + ORT_API2_STATUS(CreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* d3d_resource, _Out_ void** dml_resource); + + /** + * FreeGPUAllocation + * This api is used free the DML EP input created by CreateGPUAllocationFromD3DResource. + */ + ORT_API2_STATUS(FreeGPUAllocation, _In_ void* dml_resource); + + /** + * GetD3D12ResourceFromAllocation + * This api is used to get the D3D12 resource when a OrtValue has been allocated by the DML EP and accessed via GetMutableTensorData. + */ + ORT_API2_STATUS(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* provider, _In_ void* dml_allocation, _Out_ ID3D12Resource** d3d_resource); +}; + #ifdef __cplusplus } #endif diff --git a/include/onnxruntime/core/providers/winml/winml_provider_factory.h b/include/onnxruntime/core/providers/winml/winml_provider_factory.h index b08b42e310..19fd2f0ad7 100644 --- a/include/onnxruntime/core/providers/winml/winml_provider_factory.h +++ b/include/onnxruntime/core/providers/winml/winml_provider_factory.h @@ -1,9 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// The winml "provider factory" is not a true execution provider. +// It is placed here as an execution provider to facilitate the export of the WinMLAdapter API +// via the OrtGetWinMLAdapter method. + #include "onnxruntime_c_api.h" +struct OrtWinApi; +typedef struct OrtWinApi OrtWinApi; + struct WinmlAdapterApi; typedef struct WinmlAdapterApi WinmlAdapterApi; -ORT_EXPORT const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(_In_ const OrtApi* ort_api) NO_EXCEPTION; \ No newline at end of file +ORT_EXPORT const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(_In_ uint32_t ort_api_version) NO_EXCEPTION; \ No newline at end of file diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index b8bfdfc2e5..476364b5ad 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3061,7 +3061,6 @@ struct OrtApi { */ ORT_API2_STATUS(HasValue, _In_ const OrtValue* value, _Out_ int* out); /// @} - /// \name OrtKernelContext /// @{ /** \brief Used for custom operators, gets the GPU compute stream to use to launch the custom a GPU kernel @@ -3074,6 +3073,33 @@ struct OrtApi { * Only use it for custom kernel launching. */ ORT_API2_STATUS(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out); + + /// @} + /// \name GetTensorMemoryInfo + /// @{ + /** \brief Returns a pointer to the ::OrtMemoryInfo of a Tensor + * \param[in] ort_value ::OrtValue containing tensor. + * \param[out] mem_info ::OrtMemoryInfo of the tensor. Do NOT free the returned pointer. It is valid for the lifetime of the ::OrtValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetTensorMemoryInfo, _In_ const OrtValue* value, _Out_ const OrtMemoryInfo** mem_info); + + /// @} + /// \name GetExecutionProviderApi + /// @{ + /** \brief Get a pointer to the requested version of the Execution Provider specific + * API extensions to the OrtApi + * \param[in] provider_name The name of the execution provider name. Currently only the following + * values are supported: "DML". + * \param[in] version Must be ::ORT_API_VERSION. + * \param[out] provider_api A void pointer containing a reference to the execution provider versioned api structure. + * For example, the provider_api pointer can be cast to the OrtDmlApi* when the provider_name is "DML". + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetExecutionProviderApi, _In_ const char* provider_name, _In_ uint32_t version, _Outptr_ const void** provider_api); + /// @} /// \name SessionOptions diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 28f6c79745..0c2b74b39b 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -916,6 +916,8 @@ struct CustomOpApi { template const T* GetTensorData(_Inout_ const OrtValue* value); + const OrtMemoryInfo* GetTensorMemoryInfo(_In_ const OrtValue* value); + std::vector GetTensorShape(const OrtTensorTypeAndShapeInfo* info); void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input); size_t KernelContext_GetInputCount(const OrtKernelContext* context); diff --git a/onnxruntime/contrib_ops/cpu/signal/dft.h b/onnxruntime/contrib_ops/cpu/signal/dft.h index 54608c271d..2b04781c70 100644 --- a/onnxruntime/contrib_ops/cpu/signal/dft.h +++ b/onnxruntime/contrib_ops/cpu/signal/dft.h @@ -10,7 +10,7 @@ class DFT final : public OpKernel { bool is_onesided_ = true; public: explicit DFT(const OpKernelInfo& info) : OpKernel(info) { - is_onesided_ = info.GetAttrOrDefault("onesided", 0); + is_onesided_ = static_cast(info.GetAttrOrDefault("onesided", 0)); } Status Compute(OpKernelContext* ctx) const override; }; @@ -26,7 +26,7 @@ class STFT final : public OpKernel { bool is_onesided_ = true; public: explicit STFT(const OpKernelInfo& info) : OpKernel(info) { - is_onesided_ = info.GetAttrOrDefault("onesided", 1); + is_onesided_ = static_cast(info.GetAttrOrDefault("onesided", 1)); } Status Compute(OpKernelContext* ctx) const override; }; diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index adc2718d85..a871befb6b 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -67,6 +67,10 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA *out = new OrtMemoryInfo( onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast(id1)), id1, mem_type1); + } else if (strcmp(name1, onnxruntime::DML) == 0) { + *out = new OrtMemoryInfo( + onnxruntime::DML, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), + id1, mem_type1); } else { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported."); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp index 6e7a32e7c0..8acec1b3bb 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp @@ -43,7 +43,7 @@ namespace Dml ) : onnxruntime::IAllocator( OrtMemoryInfo( - "DML allocator", + "DML", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0) ) diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index b82db265a5..9b3df2bf7f 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -12,6 +12,7 @@ using Microsoft::WRL::ComPtr; #include "core/providers/dml/dml_provider_factory.h" #include "core/session/abi_session_options_impl.h" +#include "core/session/allocator_adapters.h" #include "core/session/ort_apis.h" #include "core/framework/error_code_helper.h" #include "DmlExecutionProvider/src/ErrorHandling.h" @@ -141,6 +142,9 @@ std::shared_ptr CreateExecutionProviderFactory_DML(in } // namespace onnxruntime +// [[deprecated]] +// This export should be deprecated. +// The OrtSessionOptionsAppendExecutionProvider_DML export on the OrtDmlApi should be used instead. ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id) { API_IMPL_BEGIN options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_DML(device_id)); @@ -148,6 +152,9 @@ API_IMPL_END return nullptr; } +// [[deprecated]] +// This export should be deprecated. +// The OrtSessionOptionsAppendExecutionProvider_DML1 export on the OrtDmlApi should be used instead. ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options, _In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue) { API_IMPL_BEGIN @@ -156,3 +163,55 @@ API_IMPL_BEGIN API_IMPL_END return nullptr; } + +ORT_API_STATUS_IMPL(CreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* d3d_resource, _Out_ void** dml_resource) { + API_IMPL_BEGIN +#ifdef USE_DML + *dml_resource = Dml::CreateGPUAllocationFromD3DResource(d3d_resource); +#else + *dml_resource = nullptr; +#endif // USE_DML + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(FreeGPUAllocation, _In_ void* ptr) { + API_IMPL_BEGIN +#ifdef USE_DML + Dml::FreeGPUAllocation(ptr); +#endif // USE_DML + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* ort_allocator, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) { + API_IMPL_BEGIN +#ifdef USE_DML + auto wrapping_allocator = static_cast(ort_allocator); + auto allocator = wrapping_allocator->GetWrappedIAllocator(); + if (!allocator) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available"); + } + *d3d_resource = Dml::GetD3D12ResourceFromAllocation(allocator.get(), allocation); + (*d3d_resource)->AddRef(); +#else + *d3d_resource = nullptr; +#endif // USE_DML + return nullptr; + API_IMPL_END +} + +static constexpr OrtDmlApi ort_dml_api_10_to_x = { + &OrtSessionOptionsAppendExecutionProvider_DML, + &OrtSessionOptionsAppendExecutionProviderEx_DML, + &CreateGPUAllocationFromD3DResource, + &FreeGPUAllocation, + &GetD3D12ResourceFromAllocation +}; + +const OrtDmlApi* GetOrtDmlApi(_In_ uint32_t /*version*/) NO_EXCEPTION { +#ifdef USE_DML + return &ort_dml_api_10_to_x; +#endif // USE_DML + return nullptr; +} \ No newline at end of file diff --git a/onnxruntime/core/session/allocator_adapters.cc b/onnxruntime/core/session/allocator_adapters.cc index 1053d0d1e8..a3a8cb273a 100644 --- a/onnxruntime/core/session/allocator_adapters.cc +++ b/onnxruntime/core/session/allocator_adapters.cc @@ -31,6 +31,10 @@ const OrtMemoryInfo* OrtAllocatorImplWrappingIAllocator::Info() const { return &i_allocator_->Info(); } +onnxruntime::AllocatorPtr OrtAllocatorImplWrappingIAllocator::GetWrappedIAllocator() { + return i_allocator_; +} + IAllocatorImplWrappingOrtAllocator::IAllocatorImplWrappingOrtAllocator(OrtAllocator* ort_allocator) : IAllocator(*ort_allocator->Info(ort_allocator)), ort_allocator_(ort_allocator) {} diff --git a/onnxruntime/core/session/allocator_adapters.h b/onnxruntime/core/session/allocator_adapters.h index 6b39dd98d1..587e9f733c 100644 --- a/onnxruntime/core/session/allocator_adapters.h +++ b/onnxruntime/core/session/allocator_adapters.h @@ -30,6 +30,8 @@ struct OrtAllocatorImplWrappingIAllocator final : public OrtAllocatorImpl { ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtAllocatorImplWrappingIAllocator); + onnxruntime::AllocatorPtr GetWrappedIAllocator(); + private: onnxruntime::AllocatorPtr i_allocator_; }; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index be11a37c4e..4fbdcac248 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -45,6 +45,11 @@ ProviderInfo_CUDA* TryGetProviderInfo_CUDA(); } #endif +#ifdef USE_DML +#include "core/providers/dml/dml_provider_factory.h" +const OrtDmlApi* GetOrtDmlApi(_In_ uint32_t version) NO_EXCEPTION; +#endif + #ifdef ENABLE_EXTENSION_CUSTOM_OPS #include "onnxruntime_extensions.h" #endif @@ -1911,6 +1916,27 @@ ORT_API_STATUS_IMPL(OrtApis::ReleaseAvailableProviders, _In_ char** ptr, return NULL; } +ORT_API_STATUS_IMPL(OrtApis::GetExecutionProviderApi, + [[maybe_unused]] _In_ const char* provider_name, + [[maybe_unused]] _In_ uint32_t version, + _Outptr_ const void** provider_api) { + API_IMPL_BEGIN + + *provider_api = nullptr; +#ifdef USE_DML + if (strcmp(provider_name, "DML") == 0) { + *provider_api = GetOrtDmlApi(version); + if (*provider_api == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified version is not supported for the DirectML provider."); + } + return NULL; + } +#endif + + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified provider is not supported."); + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count, _Outptr_ void** out) { TENSOR_READWRITE_API_BEGIN @@ -2082,6 +2108,13 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArrayWithPrepackedWeightsContainer API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::GetTensorMemoryInfo, _In_ const OrtValue* value, _Outptr_ const OrtMemoryInfo** memory_info) { + TENSOR_READ_API_BEGIN + *memory_info = &tensor.Location(); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn) { API_IMPL_BEGIN options->value.custom_create_thread_fn = ort_custom_create_thread_fn; @@ -2379,6 +2412,8 @@ static constexpr OrtApi ort_api_1_to_10 = { // Version 10 - In development, feel free to add/remove/rearrange here &OrtApis::HasValue, &OrtApis::KernelContext_GetGPUComputeStream, + &OrtApis::GetTensorMemoryInfo, + &OrtApis::GetExecutionProviderApi, &OrtApis::SessionOptionsSetCustomCreateThreadFn, &OrtApis::SessionOptionsSetCustomThreadCreationOptions, &OrtApis::SessionOptionsSetCustomJoinThreadFn, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 561870d235..5053b8ac5f 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -317,6 +317,8 @@ ORT_API_STATUS_IMPL(GetSparseTensorValues, _In_ const OrtValue* ort_value, _Outp ORT_API_STATUS_IMPL(GetSparseTensorIndicesTypeShape, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Outptr_ OrtTensorTypeAndShapeInfo** out); ORT_API_STATUS_IMPL(GetSparseTensorIndices, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Out_ size_t* num_indices, _Outptr_ const void** indices); ORT_API_STATUS_IMPL(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out); +ORT_API_STATUS_IMPL(GetTensorMemoryInfo, _In_ const OrtValue* value, _Outptr_ const OrtMemoryInfo** memory_info); +ORT_API_STATUS_IMPL(GetExecutionProviderApi, _In_ const char* provider_name, _In_ uint32_t version, _Outptr_ const void** provider_api); ORT_API_STATUS_IMPL(SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn); ORT_API_STATUS_IMPL(SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options); ORT_API_STATUS_IMPL(SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn); diff --git a/winml/adapter/winml_adapter_apis.h b/winml/adapter/winml_adapter_apis.h index a8f1690c9c..079b650115 100644 --- a/winml/adapter/winml_adapter_apis.h +++ b/winml/adapter/winml_adapter_apis.h @@ -63,16 +63,12 @@ ORT_API_STATUS(SessionGetNamedDimensionsOverrides, _In_ OrtSession* session, _Ou ORT_API_STATUS(DmlExecutionProviderSetDefaultRoundingMode, _In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled); ORT_API_STATUS(DmlExecutionProviderFlushContext, _In_ OrtExecutionProvider* dml_provider); ORT_API_STATUS(DmlExecutionProviderReleaseCompletedReferences, _In_ OrtExecutionProvider* dml_provider); -ORT_API_STATUS(DmlCreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* pResource, _Out_ void** dml_resource); -ORT_API_STATUS(DmlGetD3D12ResourceFromAllocation, _In_ OrtExecutionProvider* provider, _In_ void* allocation, _Out_ ID3D12Resource** resource); -ORT_API_STATUS(DmlFreeGPUAllocation, _In_ void* ptr); // note: this returns a weak ref ORT_API_STATUS(GetProviderMemoryInfo, _In_ OrtExecutionProvider* provider, OrtMemoryInfo** memory_info); ORT_API_STATUS(GetProviderAllocator, _In_ OrtExecutionProvider* provider, OrtAllocator** allocator); ORT_API_STATUS(FreeProviderAllocator, _In_ OrtAllocator* allocator); -ORT_API_STATUS(GetValueMemoryInfo, const OrtValue* value, OrtMemoryInfo** memory_info); // ExecutionProvider Methods ORT_API_STATUS(ExecutionProviderSync, _In_ OrtExecutionProvider* provider); diff --git a/winml/adapter/winml_adapter_c_api.cpp b/winml/adapter/winml_adapter_c_api.cpp index f6c99338b8..d3b29630e1 100644 --- a/winml/adapter/winml_adapter_c_api.cpp +++ b/winml/adapter/winml_adapter_c_api.cpp @@ -64,15 +64,11 @@ static constexpr WinmlAdapterApi winml_adapter_api_1 = { &winmla::DmlExecutionProviderSetDefaultRoundingMode, &winmla::DmlExecutionProviderFlushContext, &winmla::DmlExecutionProviderReleaseCompletedReferences, - &winmla::DmlCreateGPUAllocationFromD3DResource, - &winmla::DmlFreeGPUAllocation, - &winmla::DmlGetD3D12ResourceFromAllocation, &winmla::DmlCopyTensor, &winmla::GetProviderMemoryInfo, &winmla::GetProviderAllocator, &winmla::FreeProviderAllocator, - &winmla::GetValueMemoryInfo, &winmla::ExecutionProviderSync, @@ -100,10 +96,10 @@ static constexpr WinmlAdapterApi winml_adapter_api_1 = { &winmla::ReleaseModel }; -const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(_In_ const OrtApi* ort_api) NO_EXCEPTION { - if (OrtApis::GetApi(2) == ort_api) { +const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(_In_ uint32_t ort_version) NO_EXCEPTION { + if (ort_version >= 2) { return &winml_adapter_api_1; } return nullptr; -} +} \ No newline at end of file diff --git a/winml/adapter/winml_adapter_c_api.h b/winml/adapter/winml_adapter_c_api.h index 27c923a6ce..ce8f33f00c 100644 --- a/winml/adapter/winml_adapter_c_api.h +++ b/winml/adapter/winml_adapter_c_api.h @@ -320,9 +320,9 @@ struct WinmlAdapterApi { /** * DmlExecutionProviderSetDefaultRoundingMode - * This api is used to configure the DML EP to turn on/off rounding. + * This api is used to configure the DML EP to turn on/off rounding. * - * WinML uses this to disable rounding during session initialization and then enables it again post initialization. + * WinML uses this to disable rounding during session initialization and then enables it again post initialization. */ OrtStatus*(ORT_API_CALL* DmlExecutionProviderSetDefaultRoundingMode)(_In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled)NO_EXCEPTION; @@ -342,30 +342,6 @@ struct WinmlAdapterApi { */ OrtStatus*(ORT_API_CALL* DmlExecutionProviderReleaseCompletedReferences)(_In_ OrtExecutionProvider* dml_provider)NO_EXCEPTION; - /** - * DmlCreateGPUAllocationFromD3DResource - * This api is used to create a DML EP input based on a user specified d3d12 resource. - * - * WinML uses this as part of its Tensor apis to allow callers to specify their own D3D12 resources as inputs/outputs. - */ - OrtStatus*(ORT_API_CALL* DmlCreateGPUAllocationFromD3DResource)(_In_ ID3D12Resource* pResource, _Out_ void** dml_resource)NO_EXCEPTION; - - /** - * DmlFreeGPUAllocation - * This api is used free the DML EP input created by DmlCreateGPUAllocationFromD3DResource. - * - * WinML uses this as part of its Tensor apis to allow callers to specify their own D3D12 resources as inputs/outputs. - */ - OrtStatus*(ORT_API_CALL* DmlFreeGPUAllocation)(_In_ void* ptr)NO_EXCEPTION; - - /** - * DmlGetD3D12ResourceFromAllocation - * This api is used to get the D3D12 resource when a OrtValue has been allocated by the DML EP and accessed via GetMutableTensorData. - * - * WinML uses this in the image feature path to get the d3d resource and perform and tensorization on inputs directly into the allocated d3d12 resource. - */ - OrtStatus*(ORT_API_CALL* DmlGetD3D12ResourceFromAllocation)(_In_ OrtExecutionProvider* provider, _In_ void* allocation, _Out_ ID3D12Resource** resource)NO_EXCEPTION; - /** * DmlCopyTensor * This api is used copy a tensor allocated by the DML EP Allocator to the CPU. @@ -399,14 +375,6 @@ struct WinmlAdapterApi { */ OrtStatus*(ORT_API_CALL* FreeProviderAllocator)(_In_ OrtAllocator* allocator)NO_EXCEPTION; - /** - * GetValueMemoryInfo - * This api gets the memory info of an OrtValue. - * - * WinML uses this to determine if an OrtValue is allocated on the Cpu or elsewhere. - */ - OrtStatus*(ORT_API_CALL* GetValueMemoryInfo)(const OrtValue* value, OrtMemoryInfo** memory_info)NO_EXCEPTION; - /** * ExecutionProviderSync * This api syncs the EP. @@ -497,4 +465,4 @@ struct WinmlAdapterApi { _In_ const char* const join_node_prefix)NO_EXCEPTION; ORT_CLASS_RELEASE(Model); -}; +}; \ No newline at end of file diff --git a/winml/adapter/winml_adapter_dml.cpp b/winml/adapter/winml_adapter_dml.cpp index 1f1740236d..acda0d332b 100644 --- a/winml/adapter/winml_adapter_dml.cpp +++ b/winml/adapter/winml_adapter_dml.cpp @@ -126,42 +126,6 @@ ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderReleaseCompletedReferences, _In_ API_IMPL_END } -ORT_API_STATUS_IMPL(winmla::DmlCreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* pResource, _Out_ void** dml_resource) { - API_IMPL_BEGIN -#ifdef USE_DML - *dml_resource = Dml::CreateGPUAllocationFromD3DResource(pResource); -#else - *dml_resource = nullptr; -#endif // USE_DML USE_DML - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(winmla::DmlGetD3D12ResourceFromAllocation, _In_ OrtExecutionProvider* dml_provider, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) { - API_IMPL_BEGIN -#ifdef USE_DML - auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider); - *d3d_resource = - Dml::GetD3D12ResourceFromAllocation( - dml_provider_internal->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault).get(), - allocation); - (*d3d_resource)->AddRef(); -#else - *d3d_resource = nullptr; -#endif // USE_DML USE_DML - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(winmla::DmlFreeGPUAllocation, _In_ void* ptr) { - API_IMPL_BEGIN -#ifdef USE_DML - Dml::FreeGPUAllocation(ptr); -#endif // USE_DML USE_DML - return nullptr; - API_IMPL_END -} - ORT_API_STATUS_IMPL(winmla::DmlCopyTensor, _In_ OrtExecutionProvider* dml_provider, _In_ OrtValue* src, _In_ OrtValue* dst) { API_IMPL_BEGIN #ifdef USE_DML diff --git a/winml/adapter/winml_adapter_execution_provider.cpp b/winml/adapter/winml_adapter_execution_provider.cpp index 32343977c8..de39aed308 100644 --- a/winml/adapter/winml_adapter_execution_provider.cpp +++ b/winml/adapter/winml_adapter_execution_provider.cpp @@ -76,16 +76,4 @@ ORT_API_STATUS_IMPL(winmla::FreeProviderAllocator, _In_ OrtAllocator* allocator) delete static_cast(allocator); return nullptr; API_IMPL_END -} - -ORT_API_STATUS_IMPL(winmla::GetValueMemoryInfo, const OrtValue* value, OrtMemoryInfo** memory_info) { - API_IMPL_BEGIN - const auto& tensor = value->Get(); - auto info = tensor.Location(); - *memory_info = new OrtMemoryInfo(info.name, info.alloc_type, info.device, info.id, info.mem_type); - if (*memory_info == nullptr) { - return OrtApis::CreateStatus(ORT_FAIL, "Out of memory"); - } - return nullptr; - API_IMPL_END -} +} \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.cpp b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp index 996a06536b..e0a351de2a 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngine.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp @@ -11,6 +11,8 @@ #include "OnnxruntimeSessionBuilder.h" #include "OnnxruntimeErrors.h" +#include "core/providers/dml/dml_provider_factory.h" + using namespace _winml; static ONNXTensorElementDataType @@ -89,19 +91,17 @@ HRESULT OnnxruntimeValue::IsEmpty(bool* out) { HRESULT OnnxruntimeValue::IsCpu(bool* out) { auto ort_api = engine_->GetEngineFactory()->UseOrtApi(); - auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi(); - OrtMemoryInfo* ort_memory_info; - RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetValueMemoryInfo(value_.get(), &ort_memory_info), + const OrtMemoryInfo* ort_memory_info; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMemoryInfo(value_.get(), &ort_memory_info), ort_api); - auto memory_info = UniqueOrtMemoryInfo(ort_memory_info, ort_api->ReleaseMemoryInfo); const char* name; - RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetName(memory_info.get(), &name), + RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetName(ort_memory_info, &name), ort_api); OrtMemType type; - RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetMemType(memory_info.get(), &type), + RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetMemType(ort_memory_info, &type), ort_api); *out = !strcmp(name, "Cpu") || @@ -167,20 +167,28 @@ static auto GetStrings(const OrtApi* ort_api, const OrtValue* ort_value, HRESULT OnnxruntimeValue::GetResource(_winml::Resource& out) { auto ort_api = engine_->GetEngineFactory()->UseOrtApi(); - auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi(); void* mutable_data = nullptr; RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(value_.get(), &mutable_data), ort_api); - - OrtExecutionProvider* ort_provider; - RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(engine_->UseOrtSession(), 0, &ort_provider), + + const OrtMemoryInfo* ort_memory_info; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMemoryInfo(value_.get(), &ort_memory_info), ort_api); bool is_cpu = false; if (SUCCEEDED(IsCpu(&is_cpu)) && !is_cpu) { + const OrtDmlApi* ort_dml_api; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast(&ort_dml_api)), + ort_api); + + OrtAllocator* ort_allocator; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateAllocator(engine_->UseOrtSession(), ort_memory_info, &ort_allocator), + ort_api); + auto allocator = UniqueOrtAllocator(ort_allocator, ort_api->ReleaseAllocator); + winrt::com_ptr resource; - RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlGetD3D12ResourceFromAllocation(ort_provider, mutable_data, + RETURN_HR_IF_NOT_OK_MSG(ort_dml_api->GetD3D12ResourceFromAllocation(allocator.get(), mutable_data, resource.put()), ort_api); out = _winml::Resource(resource.get(), [](void*) { /*do nothing, as this pointer is actually a com pointer! */ }); @@ -578,7 +586,11 @@ HRESULT OnnxruntimeEngine::CreateTensorValue(const int64_t* shape, size_t count, RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetProviderAllocator(ort_provider, &ort_allocator), engine_factory_->UseOrtApi()); - auto unique_allocator = UniqueOrtAllocator(ort_allocator, winml_adapter_api->FreeProviderAllocator); // the release here should probably not return anything + auto unique_allocator = UniqueOrtAllocator( + ort_allocator, + [](OrtAllocator* allocator) { + GetVersionedWinmlAdapterApi()->FreeProviderAllocator(allocator); + }); // the release here should probably not return anything OrtValue* ort_value; RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorAsOrtValue(unique_allocator.get(), shape, count, ONNXTensorElementDataTypeFromTensorKind(kind), &ort_value), @@ -612,30 +624,35 @@ class DmlAllocatorWrapper : public Microsoft::WRL::RuntimeClass< */ HRESULT OnnxruntimeEngine::CreateTensorValueFromExternalD3DResource(ID3D12Resource* d3d_resource, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) { auto ort_api = engine_factory_->UseOrtApi(); - auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + const OrtDmlApi* ort_dml_api; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast(&ort_dml_api)), + ort_api); - OrtExecutionProvider* ort_provider; - RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider), - engine_factory_->UseOrtApi()); + OrtMemoryInfo* ort_memory_info; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateMemoryInfo("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info), + ort_api); - OrtMemoryInfo* dml_memory = nullptr; - RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(ort_provider, &dml_memory), - engine_factory_->UseOrtApi()); + OrtAllocator* ort_allocator; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateAllocator(session_.get(), ort_memory_info, &ort_allocator), + ort_api); + auto allocator = UniqueOrtAllocator(ort_allocator, ort_api->ReleaseAllocator); void* dml_allocator_resource; - RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(d3d_resource, &dml_allocator_resource), + RETURN_HR_IF_NOT_OK_MSG(ort_dml_api->CreateGPUAllocationFromD3DResource(d3d_resource, &dml_allocator_resource), engine_factory_->UseOrtApi()); auto unique_dml_allocator_resource = DmlAllocatorResource(dml_allocator_resource, [](void* ptr) { - GetVersionedWinmlAdapterApi()->DmlFreeGPUAllocation(ptr); + const OrtDmlApi* ort_dml_api; + GetVersionedOrtApi()->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast(&ort_dml_api)); + ort_dml_api->FreeGPUAllocation(ptr); }); // create the OrtValue as a tensor letting ort know that we own the data buffer OrtValue* ort_value; RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue( - dml_memory, + ort_memory_info, unique_dml_allocator_resource.get(), static_cast(d3d_resource->GetDesc().Width), shape, diff --git a/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp b/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp index 7654bdbbac..b47d1c211b 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp @@ -58,9 +58,7 @@ const OrtApi* _winml::GetVersionedOrtApi() { } const auto ort_api_base = ort_get_api_base_fn(); - - static const uint32_t ort_version = 2; - return ort_api_base->GetApi(ort_version); + return ort_api_base->GetApi(ORT_API_VERSION); } static const WinmlAdapterApi* GetVersionedWinmlAdapterApi(const OrtApi* ort_api) { @@ -73,7 +71,7 @@ static const WinmlAdapterApi* GetVersionedWinmlAdapterApi(const OrtApi* ort_api) FAIL_FAST_HR(HRESULT_FROM_WIN32(GetLastError())); } - return ort_get_winml_adapter_fn(ort_api); + return ort_get_winml_adapter_fn(ORT_API_VERSION); } const WinmlAdapterApi* _winml::GetVersionedWinmlAdapterApi() { diff --git a/winml/lib/Api.Ort/UniqueOrtPtr.h b/winml/lib/Api.Ort/UniqueOrtPtr.h index e1b24aae04..0c517623a9 100644 --- a/winml/lib/Api.Ort/UniqueOrtPtr.h +++ b/winml/lib/Api.Ort/UniqueOrtPtr.h @@ -8,7 +8,7 @@ #include "adapter/winml_adapter_c_api.h" using UniqueOrtModel = std::unique_ptr; -using UniqueOrtAllocator = std::unique_ptr; +using UniqueOrtAllocator = std::unique_ptr; using UniqueOrtSessionOptions = std::unique_ptr; using UniqueOrtSession = std::unique_ptr; using UniqueOrtValue = std::unique_ptr; @@ -17,3 +17,4 @@ using UniqueOrtTypeInfo = std::unique_ptr; using UniqueOrtRunOptions = std::unique_ptr; using UniqueOrtEnv = std::unique_ptr; + diff --git a/winml/test/adapter/AdapterDmlEpTest.cpp b/winml/test/adapter/AdapterDmlEpTest.cpp index ee744616a0..27230e076e 100644 --- a/winml/test/adapter/AdapterDmlEpTest.cpp +++ b/winml/test/adapter/AdapterDmlEpTest.cpp @@ -10,6 +10,7 @@ #include "winml_adapter_c_api.h" #include "core/framework/execution_provider.h" #include "core/providers/winml/winml_provider_factory.h" +#include "core/providers/dml/dml_provider_factory.h" #include "OnnxruntimeEngine.h" #include "OnnxruntimeErrors.h" #include "OnnxruntimeModel.h" @@ -17,15 +18,17 @@ namespace { const WinmlAdapterApi* winml_adapter_api; +const OrtDmlApi* ort_dml_api; const OrtApi* ort_api; OrtEnv* ort_env; void AdapterDmlEpTestSetup() { GPUTEST; winrt::init_apartment(); - ort_api = OrtGetApiBase()->GetApi(2); - winml_adapter_api = OrtGetWinMLAdapter(ort_api); - ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env); + ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + THROW_IF_NOT_OK_MSG(ort_api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast(&ort_dml_api)), ort_api); + winml_adapter_api = OrtGetWinMLAdapter(ORT_API_VERSION); + THROW_IF_NOT_OK_MSG(ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env), ort_api); #ifdef BUILD_INBOX winrt_activation_handler = WINRT_RoGetActivationFactory; #endif @@ -145,8 +148,8 @@ void DmlCreateAndFreeGPUAllocationFromD3DResource() { auto d3d12_resource = CreateD3D12Resource(*device); void* dml_allocator_resource; - THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(d3d12_resource.get(), &dml_allocator_resource), ort_api); - THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlFreeGPUAllocation(dml_allocator_resource), ort_api); + THROW_IF_NOT_OK_MSG(ort_dml_api->CreateGPUAllocationFromD3DResource(d3d12_resource.get(), &dml_allocator_resource), ort_api); + THROW_IF_NOT_OK_MSG(ort_dml_api->FreeGPUAllocation(dml_allocator_resource), ort_api); } void DmlGetD3D12ResourceFromAllocation() { @@ -156,68 +159,42 @@ void DmlGetD3D12ResourceFromAllocation() { auto d3d12_resource = CreateD3D12Resource(*device); void* gpu_allocation; - THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(d3d12_resource.get(), &gpu_allocation), ort_api); + THROW_IF_NOT_OK_MSG(ort_dml_api->CreateGPUAllocationFromD3DResource(d3d12_resource.get(), &gpu_allocation), ort_api); auto session = CreateDmlSession(); - OrtExecutionProvider* ort_provider; - THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api); + + OrtMemoryInfo* ort_memory_info; + THROW_IF_NOT_OK_MSG(ort_api->CreateMemoryInfo("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info), ort_api); + + OrtAllocator* ort_allocator; + THROW_IF_NOT_OK_MSG(ort_api->CreateAllocator(session.get(), ort_memory_info, &ort_allocator), ort_api); + auto allocator = UniqueOrtAllocator(ort_allocator, ort_api->ReleaseAllocator); + winrt::com_ptr d3d12_resource_from_allocation; - THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlGetD3D12ResourceFromAllocation(ort_provider, gpu_allocation, d3d12_resource_from_allocation.put()), ort_api); + THROW_IF_NOT_OK_MSG(ort_dml_api->GetD3D12ResourceFromAllocation(allocator.get(), gpu_allocation, d3d12_resource_from_allocation.put()), ort_api); // Ensure resource is the same WINML_EXPECT_EQUAL(d3d12_resource, d3d12_resource_from_allocation); - THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlFreeGPUAllocation(gpu_allocation), ort_api); + THROW_IF_NOT_OK_MSG(ort_dml_api->FreeGPUAllocation(gpu_allocation), ort_api); } -UniqueOrtValue CreateTensorFromMemoryInfo(OrtMemoryInfo* memory_info) { +UniqueOrtValue CreateTensorFromMemoryInfo(const OrtMemoryInfo* memory_info) { OrtValue* tensor; THROW_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue(memory_info, tensor_values.data(), tensor_size * sizeof(float), dimensions.data(), dimensions.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &tensor), ort_api); return UniqueOrtValue(tensor, ort_api->ReleaseValue); } -void GetProviderMemoryInfo() { +void GetTensorMemoryInfo() { GPUTEST; auto session = CreateDmlSession(); - OrtExecutionProvider* ort_provider; - THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api); - OrtMemoryInfo *memory_info; - THROW_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(ort_provider, &memory_info), ort_api); - auto unique_memory_info = UniqueOrtMemoryInfo(memory_info, ort_api->ReleaseMemoryInfo); - // Ensure tensor can be created with the provided OrtMemoryInfo - CreateTensorFromMemoryInfo(unique_memory_info.get()); -} -void GetAndFreeProviderAllocator() { - GPUTEST; - auto session = CreateDmlSession(); - OrtExecutionProvider* ort_provider; - THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api); - OrtAllocator *allocator; - THROW_IF_NOT_OK_MSG(winml_adapter_api->GetProviderAllocator(ort_provider, &allocator), ort_api); - auto unique_allocator = UniqueOrtAllocator(allocator, winml_adapter_api->FreeProviderAllocator); + OrtMemoryInfo* ort_memory_info; + THROW_IF_NOT_OK_MSG(ort_api->CreateMemoryInfo("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info), ort_api); + auto tensor = CreateTensorFromMemoryInfo(ort_memory_info); - // Ensure allocation works - void *data = nullptr; - THROW_IF_NOT_OK_MSG(ort_api->AllocatorAlloc(unique_allocator.get(), 1024, &data), ort_api); - WINML_EXPECT_NOT_EQUAL(nullptr, data); - THROW_IF_NOT_OK_MSG(ort_api->AllocatorFree(unique_allocator.get(), data), ort_api); -} - -void GetValueMemoryInfo() { - GPUTEST; - auto session = CreateDmlSession(); - OrtExecutionProvider* ort_provider; - THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api); - - OrtMemoryInfo *memory_info; - THROW_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(ort_provider, &memory_info), ort_api); - auto unique_memory_info = UniqueOrtMemoryInfo(memory_info, ort_api->ReleaseMemoryInfo); - auto tensor = CreateTensorFromMemoryInfo(unique_memory_info.get()); - - OrtMemoryInfo *value_memory_info; - THROW_IF_NOT_OK_MSG(winml_adapter_api->GetValueMemoryInfo(tensor.get(), &value_memory_info), ort_api); - auto unique_value_memory_info = UniqueOrtMemoryInfo(value_memory_info, ort_api->ReleaseMemoryInfo); - CreateTensorFromMemoryInfo(unique_value_memory_info.get()); + const OrtMemoryInfo* value_memory_info; + THROW_IF_NOT_OK_MSG(ort_api->GetTensorMemoryInfo(tensor.get(), &value_memory_info), ort_api); + CreateTensorFromMemoryInfo(value_memory_info); } void ExecutionProviderSync() { @@ -255,18 +232,17 @@ void DmlCopyTensor() { WINML_EXPECT_NOT_EQUAL(nullptr, winml_adapter_api->DmlCopyTensor(dml_provider, cpu_tensor.get(), dst_cpu_tensor.get())); // GPU to CPU - OrtMemoryInfo* dml_memory; - THROW_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(dml_provider, &dml_memory), ort_api); - auto unique_dml_memory = UniqueOrtMemoryInfo(dml_memory, ort_api->ReleaseMemoryInfo); + OrtMemoryInfo* ort_memory_info; + THROW_IF_NOT_OK_MSG(ort_api->CreateMemoryInfo("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info), ort_api); auto resource = CreateD3D12Resource(*device); void* dml_allocator_resource; - THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(resource.get(), &dml_allocator_resource), ort_api); + THROW_IF_NOT_OK_MSG(ort_dml_api->CreateGPUAllocationFromD3DResource(resource.get(), &dml_allocator_resource), ort_api); std::array shape = {720, 720, 3}; OrtValue* gpu_value; THROW_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue( - dml_memory, + ort_memory_info, dml_allocator_resource, static_cast(resource->GetDesc().Width), shape.data(), @@ -277,7 +253,7 @@ void DmlCopyTensor() { dst_cpu_tensor = CreateTensorFromMemoryInfo(cpu_memory_info); THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlCopyTensor(dml_provider, gpu_value, dst_cpu_tensor.get()), ort_api); - THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlFreeGPUAllocation(dml_allocator_resource), ort_api); + THROW_IF_NOT_OK_MSG(ort_dml_api->FreeGPUAllocation(dml_allocator_resource), ort_api); } void CreateCustomRegistry() { @@ -290,18 +266,17 @@ void CreateCustomRegistry() { void ValueGetDeviceId() { GPUTEST; auto session = CreateDmlSession(); - OrtExecutionProvider* ort_provider; - THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api); - OrtMemoryInfo *memory_info; - THROW_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(ort_provider, &memory_info), ort_api); - auto gpu_tensor = CreateTensorFromMemoryInfo(memory_info); + + OrtMemoryInfo* ort_memory_info; + THROW_IF_NOT_OK_MSG(ort_api->CreateMemoryInfo("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info), ort_api); + auto gpu_tensor = CreateTensorFromMemoryInfo(ort_memory_info); int16_t device_id; THROW_IF_NOT_OK_MSG(winml_adapter_api->ValueGetDeviceId(gpu_tensor.get(), &device_id), ort_api); OrtMemoryInfo* cpu_memory_info; THROW_IF_NOT_OK_MSG(ort_api->CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &cpu_memory_info), ort_api); - auto unique_cpu_memory_info = UniqueOrtMemoryInfo(memory_info, ort_api->ReleaseMemoryInfo); + auto unique_cpu_memory_info = UniqueOrtMemoryInfo(cpu_memory_info, ort_api->ReleaseMemoryInfo); auto cpu_tensor = CreateTensorFromMemoryInfo(unique_cpu_memory_info.get()); THROW_IF_NOT_OK_MSG(winml_adapter_api->ValueGetDeviceId(cpu_tensor.get(), &device_id), ort_api); WINML_EXPECT_EQUAL(0, device_id); @@ -329,9 +304,7 @@ const AdapterDmlEpTestApi& getapi() { DmlExecutionProviderReleaseCompletedReferences, DmlCreateAndFreeGPUAllocationFromD3DResource, DmlGetD3D12ResourceFromAllocation, - GetProviderMemoryInfo, - GetAndFreeProviderAllocator, - GetValueMemoryInfo, + GetTensorMemoryInfo, ExecutionProviderSync, DmlCopyTensor, CreateCustomRegistry, diff --git a/winml/test/adapter/AdapterDmlEpTest.h b/winml/test/adapter/AdapterDmlEpTest.h index c7b20ac88b..f08bcdc25a 100644 --- a/winml/test/adapter/AdapterDmlEpTest.h +++ b/winml/test/adapter/AdapterDmlEpTest.h @@ -11,9 +11,7 @@ struct AdapterDmlEpTestApi VoidTest DmlExecutionProviderReleaseCompletedReferences; VoidTest DmlCreateGPUAllocationFromD3DResource; VoidTest DmlCreateAndFreeGPUAllocationFromD3DResource; - VoidTest GetProviderMemoryInfo; - VoidTest GetAndFreeProviderAllocator; - VoidTest GetValueMemoryInfo; + VoidTest GetTensorMemoryInfo; VoidTest ExecutionProviderSync; VoidTest DmlCopyTensor; VoidTest CreateCustomRegistry; @@ -31,9 +29,7 @@ WINML_TEST(AdapterDmlEpTest, DmlExecutionProviderFlushContext) WINML_TEST(AdapterDmlEpTest, DmlExecutionProviderReleaseCompletedReferences) WINML_TEST(AdapterDmlEpTest, DmlCreateGPUAllocationFromD3DResource) WINML_TEST(AdapterDmlEpTest, DmlCreateAndFreeGPUAllocationFromD3DResource) -WINML_TEST(AdapterDmlEpTest, GetProviderMemoryInfo) -WINML_TEST(AdapterDmlEpTest, GetAndFreeProviderAllocator) -WINML_TEST(AdapterDmlEpTest, GetValueMemoryInfo) +WINML_TEST(AdapterDmlEpTest, GetTensorMemoryInfo) WINML_TEST(AdapterDmlEpTest, ExecutionProviderSync) WINML_TEST(AdapterDmlEpTest, DmlCopyTensor) WINML_TEST(AdapterDmlEpTest, CreateCustomRegistry) diff --git a/winml/test/adapter/AdapterSessionTest.cpp b/winml/test/adapter/AdapterSessionTest.cpp index 436eca32e3..a0d7cc2074 100644 --- a/winml/test/adapter/AdapterSessionTest.cpp +++ b/winml/test/adapter/AdapterSessionTest.cpp @@ -17,6 +17,7 @@ #include "core/common/logging/logging.h" #include "core/session/abi_session_options_impl.h" #include "core/session/ort_env.h" +#include "core/providers/dml/dml_provider_factory.h" using namespace _winml; using namespace winrt::Windows::Foundation::Collections; @@ -27,8 +28,8 @@ using namespace winrt::Windows::Storage::Streams; namespace { winrt::com_ptr<_winml::OnnxruntimeEngineFactory> engine_factory; -const OrtApi *ort_api; -const WinmlAdapterApi *winml_adapter_api; +const OrtApi* ort_api; +const WinmlAdapterApi* winml_adapter_api; OrtEnv* ort_env; void AdapterSessionTestSetup() { diff --git a/winml/test/adapter/adapter_test.cpp b/winml/test/adapter/adapter_test.cpp index f7e7906427..814db0fd12 100644 --- a/winml/test/adapter/adapter_test.cpp +++ b/winml/test/adapter/adapter_test.cpp @@ -11,8 +11,8 @@ static void AdapterTestSetup() { #ifdef BUILD_INBOX winrt_activation_handler = WINRT_RoGetActivationFactory; #endif - ort_api = OrtGetApiBase()->GetApi(2); - winml_adapter_api = OrtGetWinMLAdapter(ort_api); + ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + winml_adapter_api = OrtGetWinMLAdapter(ORT_API_VERSION); // for model tests std::wstring module_path = FileHelpers::GetModulePath();