mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Enable creating OrtValues from ID3D12Resources from the onnxruntime C-API (#9686)
* Add onnxruntime-windows api. * minor fixes * add to package headers * Build ort_dml_api for provider extensions. * Cleanup * misc comment * remove winml specific comments * use dml check in onnxruntime * Update include/onnxruntime/core/providers/dml/dml_provider_factory.h Co-authored-by: Dwayne Robinson <dwayner@microsoft.com> * Update include/onnxruntime/core/session/onnxruntime_c_api.h Co-authored-by: Dwayne Robinson <dwayner@microsoft.com> * Update include/onnxruntime/core/providers/dml/dml_provider_factory.h Co-authored-by: Dwayne Robinson <dwayner@microsoft.com> * Update include/onnxruntime/core/providers/dml/dml_provider_factory.h Co-authored-by: Dwayne Robinson <dwayner@microsoft.com> * Update onnxruntime/core/session/onnxruntime_c_api.cc Co-authored-by: Dwayne Robinson <dwayner@microsoft.com> * Update onnxruntime/core/session/ort_apis.h Co-authored-by: Dwayne Robinson <dwayner@microsoft.com> * Update winml/test/adapter/AdapterSessionTest.cpp Co-authored-by: Dwayne Robinson <dwayner@microsoft.com> * Update onnxruntime/core/session/onnxruntime_c_api.cc Co-authored-by: Dwayne Robinson <dwayner@microsoft.com> * Update winml/adapter/winml_adapter_c_api.cpp Co-authored-by: Dwayne Robinson <dwayner@microsoft.com> * Update include/onnxruntime/core/session/onnxruntime_c_api.h Co-authored-by: Pranav Sharma <prs@microsoft.com> * Update onnxruntime/core/session/onnxruntime_c_api.cc Co-authored-by: Pranav Sharma <prs@microsoft.com> * Update winml/adapter/winml_adapter_c_api.cpp * PR feedback * Update include/onnxruntime/core/providers/dml/dml_provider_factory.h Co-authored-by: Dwayne Robinson <dwayner@microsoft.com> * Update include/onnxruntime/core/providers/dml/dml_provider_factory.h Co-authored-by: Dwayne Robinson <dwayner@microsoft.com> * Update include/onnxruntime/core/providers/dml/dml_provider_factory.h Co-authored-by: Dwayne Robinson <dwayner@microsoft.com> * PR feedback * merge resolution and unreference param * (naming) Remove Dml prefix * maybe unused 
version * move DML code into DML path. CIs failing because DML is not available when --use_dml is not on * fix warning causing local build failures after merging * Change getvaluememoryinfo to gettensormemoryinfo * minor breaks * fix comment paste * fix comment Co-authored-by: Sheil Kumar <sheilk@microsoft.com> Co-authored-by: Dwayne Robinson <dwayner@microsoft.com> Co-authored-by: Pranav Sharma <prs@microsoft.com>
This commit is contained in:
parent
21eb747a0f
commit
3d0bd2596f
25 changed files with 295 additions and 203 deletions
|
|
@ -37,6 +37,7 @@ namespace onnxruntime {
|
|||
constexpr const char* CPU = "Cpu";
|
||||
constexpr const char* CUDA = "Cuda";
|
||||
constexpr const char* CUDA_PINNED = "CudaPinned";
|
||||
constexpr const char* DML = "DML";
|
||||
constexpr const char* MIGRAPHX = "MIGraphX";
|
||||
constexpr const char* MIGRAPHX_PINNED = "MIGraphXPinned";
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,11 @@ extern "C" {
|
|||
* the adapter index. The device ID corresponds to the enumeration order of hardware adapters as given by
|
||||
* IDXGIFactory::EnumAdapters. A device_id of 0 always corresponds to the default adapter, which is typically the
|
||||
* primary display GPU installed on the system. A negative device_id is invalid.
|
||||
*/
|
||||
*
|
||||
* [[deprecated]]
|
||||
* This export should be deprecated.
|
||||
* The OrtSessionOptionsAppendExecutionProvider_DML export on the OrtDmlApi should be used instead.
|
||||
*/
|
||||
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id);
|
||||
|
||||
/**
|
||||
|
|
@ -40,10 +44,58 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOpti
|
|||
* objects.
|
||||
* See also: DMLCreateDevice
|
||||
* See also: ID3D12Device::CreateCommandQueue
|
||||
*
|
||||
* [[deprecated]]
|
||||
* This export should be deprecated.
|
||||
* The OrtSessionOptionsAppendExecutionProvider_DML1 export on the OrtDmlApi should be used instead.
|
||||
*/
|
||||
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options,
|
||||
_In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue);
|
||||
|
||||
|
||||
struct OrtDmlApi;
|
||||
typedef struct OrtDmlApi OrtDmlApi;
|
||||
|
||||
struct OrtDmlApi {
|
||||
/**
|
||||
* Creates a DirectML Execution Provider which executes on the hardware adapter with the given device_id, also known as
|
||||
* the adapter index. The device ID corresponds to the enumeration order of hardware adapters as given by
|
||||
* IDXGIFactory::EnumAdapters. A device_id of 0 always corresponds to the default adapter, which is typically the
|
||||
* primary display GPU installed on the system. A negative device_id is invalid.
|
||||
*/
|
||||
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id);
|
||||
|
||||
/**
|
||||
* Creates a DirectML Execution Provider using the given DirectML device, and which executes work on the supplied D3D12
|
||||
* command queue. The DirectML device and D3D12 command queue must have the same parent ID3D12Device, or an error will
|
||||
* be returned. The D3D12 command queue must be of type DIRECT or COMPUTE (see D3D12_COMMAND_LIST_TYPE). If this
|
||||
* function succeeds, the inference session maintains a strong reference on both the dml_device and the command_queue
|
||||
* objects.
|
||||
* See also: DMLCreateDevice
|
||||
* See also: ID3D12Device::CreateCommandQueue
|
||||
*/
|
||||
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML1, _In_ OrtSessionOptions* options,
|
||||
_In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue);
|
||||
|
||||
/**
|
||||
* CreateGPUAllocationFromD3DResource
|
||||
* This api is used to create a DML EP input based on a user specified d3d12 resource.
|
||||
*/
|
||||
ORT_API2_STATUS(CreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* d3d_resource, _Out_ void** dml_resource);
|
||||
|
||||
/**
|
||||
* FreeGPUAllocation
|
||||
* This api is used free the DML EP input created by CreateGPUAllocationFromD3DResource.
|
||||
*/
|
||||
ORT_API2_STATUS(FreeGPUAllocation, _In_ void* dml_resource);
|
||||
|
||||
/**
|
||||
* GetD3D12ResourceFromAllocation
|
||||
* This api is used to get the D3D12 resource when a OrtValue has been allocated by the DML EP and accessed via GetMutableTensorData.
|
||||
*/
|
||||
ORT_API2_STATUS(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* provider, _In_ void* dml_allocation, _Out_ ID3D12Resource** d3d_resource);
|
||||
};
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -1,9 +1,16 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// The winml "provider factory" is not a true execution provider.
|
||||
// It is placed here as an execution provider to facilitate the export of the WinMLAdapter API
|
||||
// via the OrtGetWinMLAdapter method.
|
||||
|
||||
#include "onnxruntime_c_api.h"
|
||||
|
||||
struct OrtWinApi;
|
||||
typedef struct OrtWinApi OrtWinApi;
|
||||
|
||||
struct WinmlAdapterApi;
|
||||
typedef struct WinmlAdapterApi WinmlAdapterApi;
|
||||
|
||||
ORT_EXPORT const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(_In_ const OrtApi* ort_api) NO_EXCEPTION;
|
||||
ORT_EXPORT const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(_In_ uint32_t ort_api_version) NO_EXCEPTION;
|
||||
|
|
@ -3061,7 +3061,6 @@ struct OrtApi {
|
|||
*/
|
||||
ORT_API2_STATUS(HasValue, _In_ const OrtValue* value, _Out_ int* out);
|
||||
/// @}
|
||||
|
||||
/// \name OrtKernelContext
|
||||
/// @{
|
||||
/** \brief Used for custom operators, gets the GPU compute stream to use to launch the custom GPU kernel
|
||||
|
|
@ -3074,6 +3073,33 @@ struct OrtApi {
|
|||
* Only use it for custom kernel launching.
|
||||
*/
|
||||
ORT_API2_STATUS(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out);
|
||||
|
||||
/// @}
|
||||
/// \name GetTensorMemoryInfo
|
||||
/// @{
|
||||
/** \brief Returns a pointer to the ::OrtMemoryInfo of a Tensor
|
||||
* \param[in] ort_value ::OrtValue containing tensor.
|
||||
* \param[out] mem_info ::OrtMemoryInfo of the tensor. Do NOT free the returned pointer. It is valid for the lifetime of the ::OrtValue
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*/
|
||||
ORT_API2_STATUS(GetTensorMemoryInfo, _In_ const OrtValue* value, _Out_ const OrtMemoryInfo** mem_info);
|
||||
|
||||
/// @}
|
||||
/// \name GetExecutionProviderApi
|
||||
/// @{
|
||||
/** \brief Get a pointer to the requested version of the Execution Provider specific
|
||||
* API extensions to the OrtApi
|
||||
* \param[in] provider_name The name of the execution provider. Currently only the following
|
||||
* values are supported: "DML".
|
||||
* \param[in] version Must be ::ORT_API_VERSION.
|
||||
* \param[out] provider_api A void pointer containing a reference to the execution provider versioned api structure.
|
||||
* For example, the provider_api pointer can be cast to the OrtDmlApi* when the provider_name is "DML".
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*/
|
||||
ORT_API2_STATUS(GetExecutionProviderApi, _In_ const char* provider_name, _In_ uint32_t version, _Outptr_ const void** provider_api);
|
||||
|
||||
/// @}
|
||||
|
||||
/// \name SessionOptions
|
||||
|
|
|
|||
|
|
@ -916,6 +916,8 @@ struct CustomOpApi {
|
|||
template <typename T>
|
||||
const T* GetTensorData(_Inout_ const OrtValue* value);
|
||||
|
||||
const OrtMemoryInfo* GetTensorMemoryInfo(_In_ const OrtValue* value);
|
||||
|
||||
std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info);
|
||||
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input);
|
||||
size_t KernelContext_GetInputCount(const OrtKernelContext* context);
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ class DFT final : public OpKernel {
|
|||
bool is_onesided_ = true;
|
||||
public:
|
||||
explicit DFT(const OpKernelInfo& info) : OpKernel(info) {
|
||||
is_onesided_ = info.GetAttrOrDefault<int64_t>("onesided", 0);
|
||||
is_onesided_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("onesided", 0));
|
||||
}
|
||||
Status Compute(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
|
@ -26,7 +26,7 @@ class STFT final : public OpKernel {
|
|||
bool is_onesided_ = true;
|
||||
public:
|
||||
explicit STFT(const OpKernelInfo& info) : OpKernel(info) {
|
||||
is_onesided_ = info.GetAttrOrDefault<int64_t>("onesided", 1);
|
||||
is_onesided_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("onesided", 1));
|
||||
}
|
||||
Status Compute(OpKernelContext* ctx) const override;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -67,6 +67,10 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA
|
|||
*out = new OrtMemoryInfo(
|
||||
onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast<OrtDevice::DeviceId>(id1)),
|
||||
id1, mem_type1);
|
||||
} else if (strcmp(name1, onnxruntime::DML) == 0) {
|
||||
*out = new OrtMemoryInfo(
|
||||
onnxruntime::DML, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)),
|
||||
id1, mem_type1);
|
||||
} else {
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported.");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ namespace Dml
|
|||
)
|
||||
: onnxruntime::IAllocator(
|
||||
OrtMemoryInfo(
|
||||
"DML allocator",
|
||||
"DML",
|
||||
OrtAllocatorType::OrtDeviceAllocator,
|
||||
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ using Microsoft::WRL::ComPtr;
|
|||
|
||||
#include "core/providers/dml/dml_provider_factory.h"
|
||||
#include "core/session/abi_session_options_impl.h"
|
||||
#include "core/session/allocator_adapters.h"
|
||||
#include "core/session/ort_apis.h"
|
||||
#include "core/framework/error_code_helper.h"
|
||||
#include "DmlExecutionProvider/src/ErrorHandling.h"
|
||||
|
|
@ -141,6 +142,9 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(in
|
|||
|
||||
} // namespace onnxruntime
|
||||
|
||||
// [[deprecated]]
|
||||
// This export should be deprecated.
|
||||
// The OrtSessionOptionsAppendExecutionProvider_DML export on the OrtDmlApi should be used instead.
|
||||
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id) {
|
||||
API_IMPL_BEGIN
|
||||
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_DML(device_id));
|
||||
|
|
@ -148,6 +152,9 @@ API_IMPL_END
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// [[deprecated]]
|
||||
// This export should be deprecated.
|
||||
// The OrtSessionOptionsAppendExecutionProvider_DML1 export on the OrtDmlApi should be used instead.
|
||||
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options,
|
||||
_In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue) {
|
||||
API_IMPL_BEGIN
|
||||
|
|
@ -156,3 +163,55 @@ API_IMPL_BEGIN
|
|||
API_IMPL_END
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Produces the DML allocation handle for a caller-supplied D3D12 resource by forwarding to
// Dml::CreateGPUAllocationFromD3DResource. When USE_DML is not defined the handle written to
// dml_resource is nullptr and the call still reports success.
ORT_API_STATUS_IMPL(CreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* d3d_resource, _Out_ void** dml_resource) {
  API_IMPL_BEGIN
  void* gpu_allocation = nullptr;
#ifdef USE_DML
  gpu_allocation = Dml::CreateGPUAllocationFromD3DResource(d3d_resource);
#endif  // USE_DML
  // The output is only written once the handle has been produced.
  *dml_resource = gpu_allocation;
  return nullptr;
  API_IMPL_END
}
|
||||
|
||||
// Hands an allocation handle back to Dml::FreeGPUAllocation.
// When USE_DML is not defined this does nothing and reports success.
ORT_API_STATUS_IMPL(FreeGPUAllocation, _In_ void* ptr) {
  API_IMPL_BEGIN
#ifdef USE_DML
  Dml::FreeGPUAllocation(ptr);
#endif  // USE_DML
  return nullptr;
  API_IMPL_END
}
|
||||
|
||||
// Looks up the ID3D12Resource that backs `allocation` using the IAllocator wrapped by
// `ort_allocator`, and returns it through `d3d_resource` with an added COM reference
// (the caller must Release it).
//
// Returns ORT_INVALID_ARGUMENT when `d3d_resource` or `ort_allocator` is null or when the wrapper
// holds no allocator, and ORT_FAIL when no resource is found for the allocation.
// When USE_DML is not defined, `*d3d_resource` is set to nullptr and the call reports success.
//
// NOTE(review): `ort_allocator` is downcast with an unchecked static_cast, so it is presumably
// required to be an OrtAllocatorImplWrappingIAllocator - confirm against callers.
ORT_API_STATUS_IMPL(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* ort_allocator, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) {
  API_IMPL_BEGIN
  if (d3d_resource == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "d3d_resource must not be null");
  }
  // Never leave the output uninitialized, even on the error paths below.
  *d3d_resource = nullptr;
#ifdef USE_DML
  if (ort_allocator == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available");
  }
  auto wrapping_allocator = static_cast<onnxruntime::OrtAllocatorImplWrappingIAllocator*>(ort_allocator);
  auto allocator = wrapping_allocator->GetWrappedIAllocator();
  if (!allocator) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available");
  }
  ID3D12Resource* resource = Dml::GetD3D12ResourceFromAllocation(allocator.get(), allocation);
  if (resource == nullptr) {
    // Guard the AddRef below: calling through a null COM pointer would crash.
    return OrtApis::CreateStatus(ORT_FAIL, "No D3D12 resource is associated with the given allocation");
  }
  // Hand out an owning reference; publish the pointer only once it is fully valid.
  resource->AddRef();
  *d3d_resource = resource;
#endif  // USE_DML
  return nullptr;
  API_IMPL_END
}
|
||||
|
||||
// Function table returned by GetOrtDmlApi. The initializers are positional, so they must stay in
// the same order as the members of struct OrtDmlApi.
static constexpr OrtDmlApi ort_dml_api_10_to_x = {
    &OrtSessionOptionsAppendExecutionProvider_DML,    // SessionOptionsAppendExecutionProvider_DML
    &OrtSessionOptionsAppendExecutionProviderEx_DML,  // SessionOptionsAppendExecutionProvider_DML1
    &CreateGPUAllocationFromD3DResource,              // CreateGPUAllocationFromD3DResource
    &FreeGPUAllocation,                               // FreeGPUAllocation
    &GetD3D12ResourceFromAllocation,                  // GetD3D12ResourceFromAllocation
};
|
||||
|
||||
// Returns the DirectML function table, or nullptr when USE_DML is not defined.
// The requested version is currently ignored: the same table is returned for every version.
const OrtDmlApi* GetOrtDmlApi(_In_ uint32_t /*version*/) NO_EXCEPTION {
#ifdef USE_DML
  return &ort_dml_api_10_to_x;
#else
  return nullptr;
#endif  // USE_DML
}
|
||||
|
|
@ -31,6 +31,10 @@ const OrtMemoryInfo* OrtAllocatorImplWrappingIAllocator::Info() const {
|
|||
return &i_allocator_->Info();
|
||||
}
|
||||
|
||||
// Exposes the onnxruntime allocator held by this wrapper by returning a copy of the
// i_allocator_ member (by value, as an onnxruntime::AllocatorPtr).
onnxruntime::AllocatorPtr OrtAllocatorImplWrappingIAllocator::GetWrappedIAllocator() {
|
||||
return i_allocator_;
|
||||
}
|
||||
|
||||
IAllocatorImplWrappingOrtAllocator::IAllocatorImplWrappingOrtAllocator(OrtAllocator* ort_allocator)
|
||||
: IAllocator(*ort_allocator->Info(ort_allocator)), ort_allocator_(ort_allocator) {}
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ struct OrtAllocatorImplWrappingIAllocator final : public OrtAllocatorImpl {
|
|||
|
||||
ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtAllocatorImplWrappingIAllocator);
|
||||
|
||||
onnxruntime::AllocatorPtr GetWrappedIAllocator();
|
||||
|
||||
private:
|
||||
onnxruntime::AllocatorPtr i_allocator_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -45,6 +45,11 @@ ProviderInfo_CUDA* TryGetProviderInfo_CUDA();
|
|||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_DML
|
||||
#include "core/providers/dml/dml_provider_factory.h"
|
||||
const OrtDmlApi* GetOrtDmlApi(_In_ uint32_t version) NO_EXCEPTION;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_EXTENSION_CUSTOM_OPS
|
||||
#include "onnxruntime_extensions.h"
|
||||
#endif
|
||||
|
|
@ -1911,6 +1916,27 @@ ORT_API_STATUS_IMPL(OrtApis::ReleaseAvailableProviders, _In_ char** ptr,
|
|||
return NULL;
|
||||
}
|
||||
|
||||
// Returns the execution-provider-specific API table for `provider_name` through `provider_api`.
//
// Only "DML" is recognised, and only when the build defines USE_DML; in that case the table comes
// from GetOrtDmlApi(version). `*provider_api` is reset to nullptr before any lookup, so it is
// nullptr on every failure path.
//
// Returns ORT_INVALID_ARGUMENT when an argument is null, when the provider is not supported, or
// when GetOrtDmlApi has no table for the requested version.
ORT_API_STATUS_IMPL(OrtApis::GetExecutionProviderApi,
                    [[maybe_unused]] _In_ const char* provider_name,
                    [[maybe_unused]] _In_ uint32_t version,
                    _Outptr_ const void** provider_api) {
  API_IMPL_BEGIN
  // Validate before dereferencing: this is a public C entry point.
  if (provider_api == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "provider_api must not be null.");
  }
  *provider_api = nullptr;

  if (provider_name == nullptr) {
    // strcmp on a null pointer is undefined behaviour.
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "provider_name must not be null.");
  }

#ifdef USE_DML
  if (strcmp(provider_name, "DML") == 0) {
    *provider_api = GetOrtDmlApi(version);
    if (*provider_api == nullptr) {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified version is not supported for the DirectML provider.");
    }
    return nullptr;
  }
#endif

  return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified provider is not supported.");
  API_IMPL_END
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count,
|
||||
_Outptr_ void** out) {
|
||||
TENSOR_READWRITE_API_BEGIN
|
||||
|
|
@ -2082,6 +2108,13 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArrayWithPrepackedWeightsContainer
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
// Writes to `memory_info` the address of the OrtMemoryInfo reported by the tensor's Location().
// No copy is made, so the caller receives a non-owning pointer into the tensor.
// NOTE(review): `tensor` is not declared here; it is presumably introduced by
// TENSOR_READ_API_BEGIN from `value` - confirm against the macro definition.
ORT_API_STATUS_IMPL(OrtApis::GetTensorMemoryInfo, _In_ const OrtValue* value, _Outptr_ const OrtMemoryInfo** memory_info) {
  TENSOR_READ_API_BEGIN
  const OrtMemoryInfo& location = tensor.Location();
  *memory_info = &location;
  return nullptr;
  API_IMPL_END
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
|
||||
API_IMPL_BEGIN
|
||||
options->value.custom_create_thread_fn = ort_custom_create_thread_fn;
|
||||
|
|
@ -2379,6 +2412,8 @@ static constexpr OrtApi ort_api_1_to_10 = {
|
|||
// Version 10 - In development, feel free to add/remove/rearrange here
|
||||
&OrtApis::HasValue,
|
||||
&OrtApis::KernelContext_GetGPUComputeStream,
|
||||
&OrtApis::GetTensorMemoryInfo,
|
||||
&OrtApis::GetExecutionProviderApi,
|
||||
&OrtApis::SessionOptionsSetCustomCreateThreadFn,
|
||||
&OrtApis::SessionOptionsSetCustomThreadCreationOptions,
|
||||
&OrtApis::SessionOptionsSetCustomJoinThreadFn,
|
||||
|
|
|
|||
|
|
@ -317,6 +317,8 @@ ORT_API_STATUS_IMPL(GetSparseTensorValues, _In_ const OrtValue* ort_value, _Outp
|
|||
ORT_API_STATUS_IMPL(GetSparseTensorIndicesTypeShape, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Outptr_ OrtTensorTypeAndShapeInfo** out);
|
||||
ORT_API_STATUS_IMPL(GetSparseTensorIndices, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Out_ size_t* num_indices, _Outptr_ const void** indices);
|
||||
ORT_API_STATUS_IMPL(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out);
|
||||
ORT_API_STATUS_IMPL(GetTensorMemoryInfo, _In_ const OrtValue* value, _Outptr_ const OrtMemoryInfo** memory_info);
|
||||
ORT_API_STATUS_IMPL(GetExecutionProviderApi, _In_ const char* provider_name, _In_ uint32_t version, _Outptr_ const void** provider_api);
|
||||
ORT_API_STATUS_IMPL(SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn);
|
||||
ORT_API_STATUS_IMPL(SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options);
|
||||
ORT_API_STATUS_IMPL(SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn);
|
||||
|
|
|
|||
|
|
@ -63,16 +63,12 @@ ORT_API_STATUS(SessionGetNamedDimensionsOverrides, _In_ OrtSession* session, _Ou
|
|||
ORT_API_STATUS(DmlExecutionProviderSetDefaultRoundingMode, _In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled);
|
||||
ORT_API_STATUS(DmlExecutionProviderFlushContext, _In_ OrtExecutionProvider* dml_provider);
|
||||
ORT_API_STATUS(DmlExecutionProviderReleaseCompletedReferences, _In_ OrtExecutionProvider* dml_provider);
|
||||
ORT_API_STATUS(DmlCreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* pResource, _Out_ void** dml_resource);
|
||||
ORT_API_STATUS(DmlGetD3D12ResourceFromAllocation, _In_ OrtExecutionProvider* provider, _In_ void* allocation, _Out_ ID3D12Resource** resource);
|
||||
ORT_API_STATUS(DmlFreeGPUAllocation, _In_ void* ptr);
|
||||
|
||||
// note: this returns a weak ref
|
||||
|
||||
ORT_API_STATUS(GetProviderMemoryInfo, _In_ OrtExecutionProvider* provider, OrtMemoryInfo** memory_info);
|
||||
ORT_API_STATUS(GetProviderAllocator, _In_ OrtExecutionProvider* provider, OrtAllocator** allocator);
|
||||
ORT_API_STATUS(FreeProviderAllocator, _In_ OrtAllocator* allocator);
|
||||
ORT_API_STATUS(GetValueMemoryInfo, const OrtValue* value, OrtMemoryInfo** memory_info);
|
||||
|
||||
// ExecutionProvider Methods
|
||||
ORT_API_STATUS(ExecutionProviderSync, _In_ OrtExecutionProvider* provider);
|
||||
|
|
|
|||
|
|
@ -64,15 +64,11 @@ static constexpr WinmlAdapterApi winml_adapter_api_1 = {
|
|||
&winmla::DmlExecutionProviderSetDefaultRoundingMode,
|
||||
&winmla::DmlExecutionProviderFlushContext,
|
||||
&winmla::DmlExecutionProviderReleaseCompletedReferences,
|
||||
&winmla::DmlCreateGPUAllocationFromD3DResource,
|
||||
&winmla::DmlFreeGPUAllocation,
|
||||
&winmla::DmlGetD3D12ResourceFromAllocation,
|
||||
&winmla::DmlCopyTensor,
|
||||
|
||||
&winmla::GetProviderMemoryInfo,
|
||||
&winmla::GetProviderAllocator,
|
||||
&winmla::FreeProviderAllocator,
|
||||
&winmla::GetValueMemoryInfo,
|
||||
|
||||
&winmla::ExecutionProviderSync,
|
||||
|
||||
|
|
@ -100,10 +96,10 @@ static constexpr WinmlAdapterApi winml_adapter_api_1 = {
|
|||
&winmla::ReleaseModel
|
||||
};
|
||||
|
||||
const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(_In_ const OrtApi* ort_api) NO_EXCEPTION {
|
||||
if (OrtApis::GetApi(2) == ort_api) {
|
||||
const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(_In_ uint32_t ort_version) NO_EXCEPTION {
|
||||
if (ort_version >= 2) {
|
||||
return &winml_adapter_api_1;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
|
@ -320,9 +320,9 @@ struct WinmlAdapterApi {
|
|||
|
||||
/**
|
||||
* DmlExecutionProviderSetDefaultRoundingMode
|
||||
* This api is used to configure the DML EP to turn on/off rounding.
|
||||
* This api is used to configure the DML EP to turn on/off rounding.
|
||||
*
|
||||
* WinML uses this to disable rounding during session initialization and then enables it again post initialization.
|
||||
* WinML uses this to disable rounding during session initialization and then enables it again post initialization.
|
||||
*/
|
||||
OrtStatus*(ORT_API_CALL* DmlExecutionProviderSetDefaultRoundingMode)(_In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled)NO_EXCEPTION;
|
||||
|
||||
|
|
@ -342,30 +342,6 @@ struct WinmlAdapterApi {
|
|||
*/
|
||||
OrtStatus*(ORT_API_CALL* DmlExecutionProviderReleaseCompletedReferences)(_In_ OrtExecutionProvider* dml_provider)NO_EXCEPTION;
|
||||
|
||||
/**
|
||||
* DmlCreateGPUAllocationFromD3DResource
|
||||
* This api is used to create a DML EP input based on a user specified d3d12 resource.
|
||||
*
|
||||
* WinML uses this as part of its Tensor apis to allow callers to specify their own D3D12 resources as inputs/outputs.
|
||||
*/
|
||||
OrtStatus*(ORT_API_CALL* DmlCreateGPUAllocationFromD3DResource)(_In_ ID3D12Resource* pResource, _Out_ void** dml_resource)NO_EXCEPTION;
|
||||
|
||||
/**
|
||||
* DmlFreeGPUAllocation
|
||||
* This api is used to free the DML EP input created by DmlCreateGPUAllocationFromD3DResource.
|
||||
*
|
||||
* WinML uses this as part of its Tensor apis to allow callers to specify their own D3D12 resources as inputs/outputs.
|
||||
*/
|
||||
OrtStatus*(ORT_API_CALL* DmlFreeGPUAllocation)(_In_ void* ptr)NO_EXCEPTION;
|
||||
|
||||
/**
|
||||
* DmlGetD3D12ResourceFromAllocation
|
||||
* This api is used to get the D3D12 resource when a OrtValue has been allocated by the DML EP and accessed via GetMutableTensorData.
|
||||
*
|
||||
* WinML uses this in the image feature path to get the d3d resource and perform tensorization of inputs directly into the allocated d3d12 resource.
|
||||
*/
|
||||
OrtStatus*(ORT_API_CALL* DmlGetD3D12ResourceFromAllocation)(_In_ OrtExecutionProvider* provider, _In_ void* allocation, _Out_ ID3D12Resource** resource)NO_EXCEPTION;
|
||||
|
||||
/**
|
||||
* DmlCopyTensor
|
||||
* This api is used to copy a tensor allocated by the DML EP Allocator to the CPU.
|
||||
|
|
@ -399,14 +375,6 @@ struct WinmlAdapterApi {
|
|||
*/
|
||||
OrtStatus*(ORT_API_CALL* FreeProviderAllocator)(_In_ OrtAllocator* allocator)NO_EXCEPTION;
|
||||
|
||||
/**
|
||||
* GetValueMemoryInfo
|
||||
* This api gets the memory info of an OrtValue.
|
||||
*
|
||||
* WinML uses this to determine if an OrtValue is allocated on the Cpu or elsewhere.
|
||||
*/
|
||||
OrtStatus*(ORT_API_CALL* GetValueMemoryInfo)(const OrtValue* value, OrtMemoryInfo** memory_info)NO_EXCEPTION;
|
||||
|
||||
/**
|
||||
* ExecutionProviderSync
|
||||
* This api syncs the EP.
|
||||
|
|
@ -497,4 +465,4 @@ struct WinmlAdapterApi {
|
|||
_In_ const char* const join_node_prefix)NO_EXCEPTION;
|
||||
|
||||
ORT_CLASS_RELEASE(Model);
|
||||
};
|
||||
};
|
||||
|
|
@ -126,42 +126,6 @@ ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderReleaseCompletedReferences, _In_
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
// WinML adapter entry point that turns a caller-supplied D3D12 resource into a DML allocation
// handle via Dml::CreateGPUAllocationFromD3DResource. When USE_DML is not defined the handle
// written to dml_resource is nullptr and the call still reports success.
ORT_API_STATUS_IMPL(winmla::DmlCreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* pResource, _Out_ void** dml_resource) {
  API_IMPL_BEGIN
  void* gpu_allocation = nullptr;
#ifdef USE_DML
  gpu_allocation = Dml::CreateGPUAllocationFromD3DResource(pResource);
#endif  // USE_DML
  // The output is only written once the handle has been produced.
  *dml_resource = gpu_allocation;
  return nullptr;
  API_IMPL_END
}
|
||||
|
||||
// WinML adapter entry point that resolves `allocation` to its ID3D12Resource using the provider's
// default-memory allocator for device 0. The resource is returned with an added COM reference.
// When USE_DML is not defined, `*d3d_resource` is set to nullptr and the call reports success.
// NOTE(review): the lookup result is dereferenced for AddRef without a null check - confirm the
// lookup cannot return null for the allocations passed here.
ORT_API_STATUS_IMPL(winmla::DmlGetD3D12ResourceFromAllocation, _In_ OrtExecutionProvider* dml_provider, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) {
  API_IMPL_BEGIN
#ifdef USE_DML
  // The opaque C handle is reinterpreted as the internal execution provider interface.
  auto* provider = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider);
  auto default_allocator = provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault);
  *d3d_resource = Dml::GetD3D12ResourceFromAllocation(default_allocator.get(), allocation);
  (*d3d_resource)->AddRef();
#else
  *d3d_resource = nullptr;
#endif  // USE_DML
  return nullptr;
  API_IMPL_END
}
|
||||
|
||||
// WinML adapter entry point that hands an allocation handle back to Dml::FreeGPUAllocation.
// When USE_DML is not defined this does nothing and reports success.
ORT_API_STATUS_IMPL(winmla::DmlFreeGPUAllocation, _In_ void* ptr) {
  API_IMPL_BEGIN
#ifdef USE_DML
  Dml::FreeGPUAllocation(ptr);
#endif  // USE_DML
  return nullptr;
  API_IMPL_END
}
|
||||
|
||||
ORT_API_STATUS_IMPL(winmla::DmlCopyTensor, _In_ OrtExecutionProvider* dml_provider, _In_ OrtValue* src, _In_ OrtValue* dst) {
|
||||
API_IMPL_BEGIN
|
||||
#ifdef USE_DML
|
||||
|
|
|
|||
|
|
@ -76,16 +76,4 @@ ORT_API_STATUS_IMPL(winmla::FreeProviderAllocator, _In_ OrtAllocator* allocator)
|
|||
delete static_cast<OrtAllocatorWrapper*>(allocator);
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
// Allocates a new OrtMemoryInfo describing where the tensor held by `value` lives and returns it
// through `memory_info`. The object is created with `new`, so ownership passes to the caller.
//
// The previous "Out of memory" branch was removed: a plain `new` expression reports failure by
// throwing std::bad_alloc and never yields nullptr, so that check could not fire.
// NOTE(review): allocation failure is presumably translated into a status by
// API_IMPL_BEGIN / API_IMPL_END - confirm against the macro definitions.
ORT_API_STATUS_IMPL(winmla::GetValueMemoryInfo, const OrtValue* value, OrtMemoryInfo** memory_info) {
  API_IMPL_BEGIN
  const auto& tensor = value->Get<onnxruntime::Tensor>();
  // Bind by reference: the fields are only read to build the copy below.
  const auto& info = tensor.Location();
  *memory_info = new OrtMemoryInfo(info.name, info.alloc_type, info.device, info.id, info.mem_type);
  return nullptr;
  API_IMPL_END
}
|
||||
}
|
||||
|
|
@ -11,6 +11,8 @@
|
|||
#include "OnnxruntimeSessionBuilder.h"
|
||||
#include "OnnxruntimeErrors.h"
|
||||
|
||||
#include "core/providers/dml/dml_provider_factory.h"
|
||||
|
||||
using namespace _winml;
|
||||
|
||||
static ONNXTensorElementDataType
|
||||
|
|
@ -89,19 +91,17 @@ HRESULT OnnxruntimeValue::IsEmpty(bool* out) {
|
|||
|
||||
HRESULT OnnxruntimeValue::IsCpu(bool* out) {
|
||||
auto ort_api = engine_->GetEngineFactory()->UseOrtApi();
|
||||
auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi();
|
||||
|
||||
OrtMemoryInfo* ort_memory_info;
|
||||
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetValueMemoryInfo(value_.get(), &ort_memory_info),
|
||||
const OrtMemoryInfo* ort_memory_info;
|
||||
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMemoryInfo(value_.get(), &ort_memory_info),
|
||||
ort_api);
|
||||
auto memory_info = UniqueOrtMemoryInfo(ort_memory_info, ort_api->ReleaseMemoryInfo);
|
||||
|
||||
const char* name;
|
||||
RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetName(memory_info.get(), &name),
|
||||
RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetName(ort_memory_info, &name),
|
||||
ort_api);
|
||||
|
||||
OrtMemType type;
|
||||
RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetMemType(memory_info.get(), &type),
|
||||
RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetMemType(ort_memory_info, &type),
|
||||
ort_api);
|
||||
|
||||
*out = !strcmp(name, "Cpu") ||
|
||||
|
|
@ -167,20 +167,28 @@ static auto GetStrings(const OrtApi* ort_api, const OrtValue* ort_value,
|
|||
|
||||
HRESULT OnnxruntimeValue::GetResource(_winml::Resource& out) {
|
||||
auto ort_api = engine_->GetEngineFactory()->UseOrtApi();
|
||||
auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi();
|
||||
|
||||
void* mutable_data = nullptr;
|
||||
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(value_.get(), &mutable_data),
|
||||
ort_api);
|
||||
|
||||
OrtExecutionProvider* ort_provider;
|
||||
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(engine_->UseOrtSession(), 0, &ort_provider),
|
||||
|
||||
const OrtMemoryInfo* ort_memory_info;
|
||||
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMemoryInfo(value_.get(), &ort_memory_info),
|
||||
ort_api);
|
||||
|
||||
bool is_cpu = false;
|
||||
if (SUCCEEDED(IsCpu(&is_cpu)) && !is_cpu) {
|
||||
const OrtDmlApi* ort_dml_api;
|
||||
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&ort_dml_api)),
|
||||
ort_api);
|
||||
|
||||
OrtAllocator* ort_allocator;
|
||||
RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateAllocator(engine_->UseOrtSession(), ort_memory_info, &ort_allocator),
|
||||
ort_api);
|
||||
auto allocator = UniqueOrtAllocator(ort_allocator, ort_api->ReleaseAllocator);
|
||||
|
||||
winrt::com_ptr<ID3D12Resource> resource;
|
||||
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlGetD3D12ResourceFromAllocation(ort_provider, mutable_data,
|
||||
RETURN_HR_IF_NOT_OK_MSG(ort_dml_api->GetD3D12ResourceFromAllocation(allocator.get(), mutable_data,
|
||||
resource.put()),
|
||||
ort_api);
|
||||
out = _winml::Resource(resource.get(), [](void*) { /*do nothing, as this pointer is actually a com pointer! */ });
|
||||
|
|
@ -578,7 +586,11 @@ HRESULT OnnxruntimeEngine::CreateTensorValue(const int64_t* shape, size_t count,
|
|||
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetProviderAllocator(ort_provider, &ort_allocator),
|
||||
engine_factory_->UseOrtApi());
|
||||
|
||||
auto unique_allocator = UniqueOrtAllocator(ort_allocator, winml_adapter_api->FreeProviderAllocator); // the release here should probably not return anything
|
||||
auto unique_allocator = UniqueOrtAllocator(
|
||||
ort_allocator,
|
||||
[](OrtAllocator* allocator) {
|
||||
GetVersionedWinmlAdapterApi()->FreeProviderAllocator(allocator);
|
||||
}); // the release here should probably not return anything
|
||||
|
||||
OrtValue* ort_value;
|
||||
RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorAsOrtValue(unique_allocator.get(), shape, count, ONNXTensorElementDataTypeFromTensorKind(kind), &ort_value),
|
||||
|
|
@ -612,30 +624,35 @@ class DmlAllocatorWrapper : public Microsoft::WRL::RuntimeClass<
|
|||
*/
|
||||
HRESULT OnnxruntimeEngine::CreateTensorValueFromExternalD3DResource(ID3D12Resource* d3d_resource, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) {
|
||||
auto ort_api = engine_factory_->UseOrtApi();
|
||||
auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
|
||||
const OrtDmlApi* ort_dml_api;
|
||||
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&ort_dml_api)),
|
||||
ort_api);
|
||||
|
||||
OrtExecutionProvider* ort_provider;
|
||||
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider),
|
||||
engine_factory_->UseOrtApi());
|
||||
OrtMemoryInfo* ort_memory_info;
|
||||
RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateMemoryInfo("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info),
|
||||
ort_api);
|
||||
|
||||
OrtMemoryInfo* dml_memory = nullptr;
|
||||
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(ort_provider, &dml_memory),
|
||||
engine_factory_->UseOrtApi());
|
||||
OrtAllocator* ort_allocator;
|
||||
RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateAllocator(session_.get(), ort_memory_info, &ort_allocator),
|
||||
ort_api);
|
||||
auto allocator = UniqueOrtAllocator(ort_allocator, ort_api->ReleaseAllocator);
|
||||
|
||||
void* dml_allocator_resource;
|
||||
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(d3d_resource, &dml_allocator_resource),
|
||||
RETURN_HR_IF_NOT_OK_MSG(ort_dml_api->CreateGPUAllocationFromD3DResource(d3d_resource, &dml_allocator_resource),
|
||||
engine_factory_->UseOrtApi());
|
||||
|
||||
auto unique_dml_allocator_resource =
|
||||
DmlAllocatorResource(dml_allocator_resource,
|
||||
[](void* ptr) {
|
||||
GetVersionedWinmlAdapterApi()->DmlFreeGPUAllocation(ptr);
|
||||
const OrtDmlApi* ort_dml_api;
|
||||
GetVersionedOrtApi()->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&ort_dml_api));
|
||||
ort_dml_api->FreeGPUAllocation(ptr);
|
||||
});
|
||||
|
||||
// create the OrtValue as a tensor letting ort know that we own the data buffer
|
||||
OrtValue* ort_value;
|
||||
RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue(
|
||||
dml_memory,
|
||||
ort_memory_info,
|
||||
unique_dml_allocator_resource.get(),
|
||||
static_cast<size_t>(d3d_resource->GetDesc().Width),
|
||||
shape,
|
||||
|
|
|
|||
|
|
@ -58,9 +58,7 @@ const OrtApi* _winml::GetVersionedOrtApi() {
|
|||
}
|
||||
|
||||
const auto ort_api_base = ort_get_api_base_fn();
|
||||
|
||||
static const uint32_t ort_version = 2;
|
||||
return ort_api_base->GetApi(ort_version);
|
||||
return ort_api_base->GetApi(ORT_API_VERSION);
|
||||
}
|
||||
|
||||
static const WinmlAdapterApi* GetVersionedWinmlAdapterApi(const OrtApi* ort_api) {
|
||||
|
|
@ -73,7 +71,7 @@ static const WinmlAdapterApi* GetVersionedWinmlAdapterApi(const OrtApi* ort_api)
|
|||
FAIL_FAST_HR(HRESULT_FROM_WIN32(GetLastError()));
|
||||
}
|
||||
|
||||
return ort_get_winml_adapter_fn(ort_api);
|
||||
return ort_get_winml_adapter_fn(ORT_API_VERSION);
|
||||
}
|
||||
|
||||
const WinmlAdapterApi* _winml::GetVersionedWinmlAdapterApi() {
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
#include "adapter/winml_adapter_c_api.h"
|
||||
|
||||
using UniqueOrtModel = std::unique_ptr<OrtModel, decltype(WinmlAdapterApi::ReleaseModel)>;
|
||||
using UniqueOrtAllocator = std::unique_ptr<OrtAllocator, decltype(WinmlAdapterApi::FreeProviderAllocator)>;
|
||||
using UniqueOrtAllocator = std::unique_ptr<OrtAllocator, decltype(OrtApi::ReleaseAllocator)>;
|
||||
using UniqueOrtSessionOptions = std::unique_ptr<OrtSessionOptions, decltype(OrtApi::ReleaseSessionOptions)>;
|
||||
using UniqueOrtSession = std::unique_ptr<OrtSession, decltype(OrtApi::ReleaseSession)>;
|
||||
using UniqueOrtValue = std::unique_ptr<OrtValue, decltype(OrtApi::ReleaseValue)>;
|
||||
|
|
@ -17,3 +17,4 @@ using UniqueOrtTypeInfo = std::unique_ptr<OrtTypeInfo,
|
|||
using UniqueOrtTensorTypeAndShapeInfo = std::unique_ptr<OrtTensorTypeAndShapeInfo, decltype(OrtApi::ReleaseTensorTypeAndShapeInfo)>;
|
||||
using UniqueOrtRunOptions = std::unique_ptr<OrtRunOptions, decltype(OrtApi::ReleaseRunOptions)>;
|
||||
using UniqueOrtEnv = std::unique_ptr<OrtEnv, decltype(OrtApi::ReleaseEnv)>;
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#include "winml_adapter_c_api.h"
|
||||
#include "core/framework/execution_provider.h"
|
||||
#include "core/providers/winml/winml_provider_factory.h"
|
||||
#include "core/providers/dml/dml_provider_factory.h"
|
||||
#include "OnnxruntimeEngine.h"
|
||||
#include "OnnxruntimeErrors.h"
|
||||
#include "OnnxruntimeModel.h"
|
||||
|
|
@ -17,15 +18,17 @@
|
|||
|
||||
namespace {
|
||||
const WinmlAdapterApi* winml_adapter_api;
|
||||
const OrtDmlApi* ort_dml_api;
|
||||
const OrtApi* ort_api;
|
||||
OrtEnv* ort_env;
|
||||
|
||||
void AdapterDmlEpTestSetup() {
|
||||
GPUTEST;
|
||||
winrt::init_apartment();
|
||||
ort_api = OrtGetApiBase()->GetApi(2);
|
||||
winml_adapter_api = OrtGetWinMLAdapter(ort_api);
|
||||
ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env);
|
||||
ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
||||
THROW_IF_NOT_OK_MSG(ort_api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&ort_dml_api)), ort_api);
|
||||
winml_adapter_api = OrtGetWinMLAdapter(ORT_API_VERSION);
|
||||
THROW_IF_NOT_OK_MSG(ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env), ort_api);
|
||||
#ifdef BUILD_INBOX
|
||||
winrt_activation_handler = WINRT_RoGetActivationFactory;
|
||||
#endif
|
||||
|
|
@ -145,8 +148,8 @@ void DmlCreateAndFreeGPUAllocationFromD3DResource() {
|
|||
|
||||
auto d3d12_resource = CreateD3D12Resource(*device);
|
||||
void* dml_allocator_resource;
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(d3d12_resource.get(), &dml_allocator_resource), ort_api);
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlFreeGPUAllocation(dml_allocator_resource), ort_api);
|
||||
THROW_IF_NOT_OK_MSG(ort_dml_api->CreateGPUAllocationFromD3DResource(d3d12_resource.get(), &dml_allocator_resource), ort_api);
|
||||
THROW_IF_NOT_OK_MSG(ort_dml_api->FreeGPUAllocation(dml_allocator_resource), ort_api);
|
||||
}
|
||||
|
||||
void DmlGetD3D12ResourceFromAllocation() {
|
||||
|
|
@ -156,68 +159,42 @@ void DmlGetD3D12ResourceFromAllocation() {
|
|||
|
||||
auto d3d12_resource = CreateD3D12Resource(*device);
|
||||
void* gpu_allocation;
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(d3d12_resource.get(), &gpu_allocation), ort_api);
|
||||
THROW_IF_NOT_OK_MSG(ort_dml_api->CreateGPUAllocationFromD3DResource(d3d12_resource.get(), &gpu_allocation), ort_api);
|
||||
|
||||
auto session = CreateDmlSession();
|
||||
OrtExecutionProvider* ort_provider;
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api);
|
||||
|
||||
OrtMemoryInfo* ort_memory_info;
|
||||
THROW_IF_NOT_OK_MSG(ort_api->CreateMemoryInfo("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info), ort_api);
|
||||
|
||||
OrtAllocator* ort_allocator;
|
||||
THROW_IF_NOT_OK_MSG(ort_api->CreateAllocator(session.get(), ort_memory_info, &ort_allocator), ort_api);
|
||||
auto allocator = UniqueOrtAllocator(ort_allocator, ort_api->ReleaseAllocator);
|
||||
|
||||
winrt::com_ptr<ID3D12Resource> d3d12_resource_from_allocation;
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlGetD3D12ResourceFromAllocation(ort_provider, gpu_allocation, d3d12_resource_from_allocation.put()), ort_api);
|
||||
THROW_IF_NOT_OK_MSG(ort_dml_api->GetD3D12ResourceFromAllocation(allocator.get(), gpu_allocation, d3d12_resource_from_allocation.put()), ort_api);
|
||||
// Ensure resource is the same
|
||||
WINML_EXPECT_EQUAL(d3d12_resource, d3d12_resource_from_allocation);
|
||||
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlFreeGPUAllocation(gpu_allocation), ort_api);
|
||||
THROW_IF_NOT_OK_MSG(ort_dml_api->FreeGPUAllocation(gpu_allocation), ort_api);
|
||||
}
|
||||
|
||||
UniqueOrtValue CreateTensorFromMemoryInfo(OrtMemoryInfo* memory_info) {
|
||||
UniqueOrtValue CreateTensorFromMemoryInfo(const OrtMemoryInfo* memory_info) {
|
||||
OrtValue* tensor;
|
||||
THROW_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue(memory_info, tensor_values.data(), tensor_size * sizeof(float), dimensions.data(), dimensions.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &tensor), ort_api);
|
||||
return UniqueOrtValue(tensor, ort_api->ReleaseValue);
|
||||
}
|
||||
|
||||
void GetProviderMemoryInfo() {
|
||||
void GetTensorMemoryInfo() {
|
||||
GPUTEST;
|
||||
auto session = CreateDmlSession();
|
||||
OrtExecutionProvider* ort_provider;
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api);
|
||||
OrtMemoryInfo *memory_info;
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(ort_provider, &memory_info), ort_api);
|
||||
auto unique_memory_info = UniqueOrtMemoryInfo(memory_info, ort_api->ReleaseMemoryInfo);
|
||||
// Ensure tensor can be created with the provided OrtMemoryInfo
|
||||
CreateTensorFromMemoryInfo(unique_memory_info.get());
|
||||
}
|
||||
|
||||
void GetAndFreeProviderAllocator() {
|
||||
GPUTEST;
|
||||
auto session = CreateDmlSession();
|
||||
OrtExecutionProvider* ort_provider;
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api);
|
||||
OrtAllocator *allocator;
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->GetProviderAllocator(ort_provider, &allocator), ort_api);
|
||||
auto unique_allocator = UniqueOrtAllocator(allocator, winml_adapter_api->FreeProviderAllocator);
|
||||
OrtMemoryInfo* ort_memory_info;
|
||||
THROW_IF_NOT_OK_MSG(ort_api->CreateMemoryInfo("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info), ort_api);
|
||||
auto tensor = CreateTensorFromMemoryInfo(ort_memory_info);
|
||||
|
||||
// Ensure allocation works
|
||||
void *data = nullptr;
|
||||
THROW_IF_NOT_OK_MSG(ort_api->AllocatorAlloc(unique_allocator.get(), 1024, &data), ort_api);
|
||||
WINML_EXPECT_NOT_EQUAL(nullptr, data);
|
||||
THROW_IF_NOT_OK_MSG(ort_api->AllocatorFree(unique_allocator.get(), data), ort_api);
|
||||
}
|
||||
|
||||
void GetValueMemoryInfo() {
|
||||
GPUTEST;
|
||||
auto session = CreateDmlSession();
|
||||
OrtExecutionProvider* ort_provider;
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api);
|
||||
|
||||
OrtMemoryInfo *memory_info;
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(ort_provider, &memory_info), ort_api);
|
||||
auto unique_memory_info = UniqueOrtMemoryInfo(memory_info, ort_api->ReleaseMemoryInfo);
|
||||
auto tensor = CreateTensorFromMemoryInfo(unique_memory_info.get());
|
||||
|
||||
OrtMemoryInfo *value_memory_info;
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->GetValueMemoryInfo(tensor.get(), &value_memory_info), ort_api);
|
||||
auto unique_value_memory_info = UniqueOrtMemoryInfo(value_memory_info, ort_api->ReleaseMemoryInfo);
|
||||
CreateTensorFromMemoryInfo(unique_value_memory_info.get());
|
||||
const OrtMemoryInfo* value_memory_info;
|
||||
THROW_IF_NOT_OK_MSG(ort_api->GetTensorMemoryInfo(tensor.get(), &value_memory_info), ort_api);
|
||||
CreateTensorFromMemoryInfo(value_memory_info);
|
||||
}
|
||||
|
||||
void ExecutionProviderSync() {
|
||||
|
|
@ -255,18 +232,17 @@ void DmlCopyTensor() {
|
|||
WINML_EXPECT_NOT_EQUAL(nullptr, winml_adapter_api->DmlCopyTensor(dml_provider, cpu_tensor.get(), dst_cpu_tensor.get()));
|
||||
|
||||
// GPU to CPU
|
||||
OrtMemoryInfo* dml_memory;
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(dml_provider, &dml_memory), ort_api);
|
||||
auto unique_dml_memory = UniqueOrtMemoryInfo(dml_memory, ort_api->ReleaseMemoryInfo);
|
||||
OrtMemoryInfo* ort_memory_info;
|
||||
THROW_IF_NOT_OK_MSG(ort_api->CreateMemoryInfo("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info), ort_api);
|
||||
|
||||
auto resource = CreateD3D12Resource(*device);
|
||||
void* dml_allocator_resource;
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(resource.get(), &dml_allocator_resource), ort_api);
|
||||
THROW_IF_NOT_OK_MSG(ort_dml_api->CreateGPUAllocationFromD3DResource(resource.get(), &dml_allocator_resource), ort_api);
|
||||
|
||||
std::array<int64_t, 3> shape = {720, 720, 3};
|
||||
OrtValue* gpu_value;
|
||||
THROW_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue(
|
||||
dml_memory,
|
||||
ort_memory_info,
|
||||
dml_allocator_resource,
|
||||
static_cast<size_t>(resource->GetDesc().Width),
|
||||
shape.data(),
|
||||
|
|
@ -277,7 +253,7 @@ void DmlCopyTensor() {
|
|||
dst_cpu_tensor = CreateTensorFromMemoryInfo(cpu_memory_info);
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlCopyTensor(dml_provider, gpu_value, dst_cpu_tensor.get()), ort_api);
|
||||
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlFreeGPUAllocation(dml_allocator_resource), ort_api);
|
||||
THROW_IF_NOT_OK_MSG(ort_dml_api->FreeGPUAllocation(dml_allocator_resource), ort_api);
|
||||
}
|
||||
|
||||
void CreateCustomRegistry() {
|
||||
|
|
@ -290,18 +266,17 @@ void CreateCustomRegistry() {
|
|||
void ValueGetDeviceId() {
|
||||
GPUTEST;
|
||||
auto session = CreateDmlSession();
|
||||
OrtExecutionProvider* ort_provider;
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api);
|
||||
OrtMemoryInfo *memory_info;
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(ort_provider, &memory_info), ort_api);
|
||||
auto gpu_tensor = CreateTensorFromMemoryInfo(memory_info);
|
||||
|
||||
OrtMemoryInfo* ort_memory_info;
|
||||
THROW_IF_NOT_OK_MSG(ort_api->CreateMemoryInfo("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info), ort_api);
|
||||
auto gpu_tensor = CreateTensorFromMemoryInfo(ort_memory_info);
|
||||
|
||||
int16_t device_id;
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->ValueGetDeviceId(gpu_tensor.get(), &device_id), ort_api);
|
||||
|
||||
OrtMemoryInfo* cpu_memory_info;
|
||||
THROW_IF_NOT_OK_MSG(ort_api->CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &cpu_memory_info), ort_api);
|
||||
auto unique_cpu_memory_info = UniqueOrtMemoryInfo(memory_info, ort_api->ReleaseMemoryInfo);
|
||||
auto unique_cpu_memory_info = UniqueOrtMemoryInfo(cpu_memory_info, ort_api->ReleaseMemoryInfo);
|
||||
auto cpu_tensor = CreateTensorFromMemoryInfo(unique_cpu_memory_info.get());
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->ValueGetDeviceId(cpu_tensor.get(), &device_id), ort_api);
|
||||
WINML_EXPECT_EQUAL(0, device_id);
|
||||
|
|
@ -329,9 +304,7 @@ const AdapterDmlEpTestApi& getapi() {
|
|||
DmlExecutionProviderReleaseCompletedReferences,
|
||||
DmlCreateAndFreeGPUAllocationFromD3DResource,
|
||||
DmlGetD3D12ResourceFromAllocation,
|
||||
GetProviderMemoryInfo,
|
||||
GetAndFreeProviderAllocator,
|
||||
GetValueMemoryInfo,
|
||||
GetTensorMemoryInfo,
|
||||
ExecutionProviderSync,
|
||||
DmlCopyTensor,
|
||||
CreateCustomRegistry,
|
||||
|
|
|
|||
|
|
@ -11,9 +11,7 @@ struct AdapterDmlEpTestApi
|
|||
VoidTest DmlExecutionProviderReleaseCompletedReferences;
|
||||
VoidTest DmlCreateGPUAllocationFromD3DResource;
|
||||
VoidTest DmlCreateAndFreeGPUAllocationFromD3DResource;
|
||||
VoidTest GetProviderMemoryInfo;
|
||||
VoidTest GetAndFreeProviderAllocator;
|
||||
VoidTest GetValueMemoryInfo;
|
||||
VoidTest GetTensorMemoryInfo;
|
||||
VoidTest ExecutionProviderSync;
|
||||
VoidTest DmlCopyTensor;
|
||||
VoidTest CreateCustomRegistry;
|
||||
|
|
@ -31,9 +29,7 @@ WINML_TEST(AdapterDmlEpTest, DmlExecutionProviderFlushContext)
|
|||
WINML_TEST(AdapterDmlEpTest, DmlExecutionProviderReleaseCompletedReferences)
|
||||
WINML_TEST(AdapterDmlEpTest, DmlCreateGPUAllocationFromD3DResource)
|
||||
WINML_TEST(AdapterDmlEpTest, DmlCreateAndFreeGPUAllocationFromD3DResource)
|
||||
WINML_TEST(AdapterDmlEpTest, GetProviderMemoryInfo)
|
||||
WINML_TEST(AdapterDmlEpTest, GetAndFreeProviderAllocator)
|
||||
WINML_TEST(AdapterDmlEpTest, GetValueMemoryInfo)
|
||||
WINML_TEST(AdapterDmlEpTest, GetTensorMemoryInfo)
|
||||
WINML_TEST(AdapterDmlEpTest, ExecutionProviderSync)
|
||||
WINML_TEST(AdapterDmlEpTest, DmlCopyTensor)
|
||||
WINML_TEST(AdapterDmlEpTest, CreateCustomRegistry)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
#include "core/common/logging/logging.h"
|
||||
#include "core/session/abi_session_options_impl.h"
|
||||
#include "core/session/ort_env.h"
|
||||
#include "core/providers/dml/dml_provider_factory.h"
|
||||
|
||||
using namespace _winml;
|
||||
using namespace winrt::Windows::Foundation::Collections;
|
||||
|
|
@ -27,8 +28,8 @@ using namespace winrt::Windows::Storage::Streams;
|
|||
|
||||
namespace {
|
||||
winrt::com_ptr<_winml::OnnxruntimeEngineFactory> engine_factory;
|
||||
const OrtApi *ort_api;
|
||||
const WinmlAdapterApi *winml_adapter_api;
|
||||
const OrtApi* ort_api;
|
||||
const WinmlAdapterApi* winml_adapter_api;
|
||||
OrtEnv* ort_env;
|
||||
|
||||
void AdapterSessionTestSetup() {
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ static void AdapterTestSetup() {
|
|||
#ifdef BUILD_INBOX
|
||||
winrt_activation_handler = WINRT_RoGetActivationFactory;
|
||||
#endif
|
||||
ort_api = OrtGetApiBase()->GetApi(2);
|
||||
winml_adapter_api = OrtGetWinMLAdapter(ort_api);
|
||||
ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
||||
winml_adapter_api = OrtGetWinMLAdapter(ORT_API_VERSION);
|
||||
|
||||
// for model tests
|
||||
std::wstring module_path = FileHelpers::GetModulePath();
|
||||
|
|
|
|||
Loading…
Reference in a new issue