Enable creating OrtValues from ID3D12Resources from the onnxruntime C-API (#9686)

* Add onnxruntime-windows api.

* minor fixes

* add to package headers

* Build ort_dml_api for provider extensions.

* Cleanup

* misc comment

* remove winml specific comments

* use dml check in onnxruntime

* Update include/onnxruntime/core/providers/dml/dml_provider_factory.h

Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>

* Update include/onnxruntime/core/session/onnxruntime_c_api.h

Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>

* Update include/onnxruntime/core/providers/dml/dml_provider_factory.h

Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>

* Update include/onnxruntime/core/providers/dml/dml_provider_factory.h

Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>

* Update onnxruntime/core/session/onnxruntime_c_api.cc

Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>

* Update onnxruntime/core/session/ort_apis.h

Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>

* Update winml/test/adapter/AdapterSessionTest.cpp

Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>

* Update onnxruntime/core/session/onnxruntime_c_api.cc

Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>

* Update winml/adapter/winml_adapter_c_api.cpp

Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>

* Update include/onnxruntime/core/session/onnxruntime_c_api.h

Co-authored-by: Pranav Sharma <prs@microsoft.com>

* Update onnxruntime/core/session/onnxruntime_c_api.cc

Co-authored-by: Pranav Sharma <prs@microsoft.com>

* Update winml/adapter/winml_adapter_c_api.cpp

* PR feedback

* Update include/onnxruntime/core/providers/dml/dml_provider_factory.h

Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>

* Update include/onnxruntime/core/providers/dml/dml_provider_factory.h

Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>

* Update include/onnxruntime/core/providers/dml/dml_provider_factory.h

Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>

* PR feedback

* merge resolution and unreference param

* (naming) Remove Dml prefix

* maybe unused version

* move DML code into DML path. CIs failing because DML is not available when --use_dml is not on

* fix warning causing local build failures after merging

* Change getvaluememoryinfo to gettensormemoryinfo

* minor breaks

* fix comment paste

* fix comment

Co-authored-by: Sheil Kumar <sheilk@microsoft.com>
Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>
Co-authored-by: Pranav Sharma <prs@microsoft.com>
This commit is contained in:
Sheil Kumar 2021-11-13 03:34:54 -08:00 committed by GitHub
parent 21eb747a0f
commit 3d0bd2596f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 295 additions and 203 deletions

View file

@ -37,6 +37,7 @@ namespace onnxruntime {
constexpr const char* CPU = "Cpu";
constexpr const char* CUDA = "Cuda";
constexpr const char* CUDA_PINNED = "CudaPinned";
constexpr const char* DML = "DML";
constexpr const char* MIGRAPHX = "MIGraphX";
constexpr const char* MIGRAPHX_PINNED = "MIGraphXPinned";

View file

@ -29,7 +29,11 @@ extern "C" {
* the adapter index. The device ID corresponds to the enumeration order of hardware adapters as given by
* IDXGIFactory::EnumAdapters. A device_id of 0 always corresponds to the default adapter, which is typically the
* primary display GPU installed on the system. A negative device_id is invalid.
*/
*
* [[deprecated]]
* This export should be deprecated.
* The OrtSessionOptionsAppendExecutionProvider_DML export on the OrtDmlApi should be used instead.
*/
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id);
/**
@ -40,10 +44,58 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOpti
* objects.
* See also: DMLCreateDevice
* See also: ID3D12Device::CreateCommandQueue
*
* [[deprecated]]
* This export should be deprecated.
* The OrtSessionOptionsAppendExecutionProvider_DML1 export on the OrtDmlApi should be used instead.
*/
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options,
_In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue);
struct OrtDmlApi;
typedef struct OrtDmlApi OrtDmlApi;
struct OrtDmlApi {
/**
* Creates a DirectML Execution Provider which executes on the hardware adapter with the given device_id, also known as
* the adapter index. The device ID corresponds to the enumeration order of hardware adapters as given by
* IDXGIFactory::EnumAdapters. A device_id of 0 always corresponds to the default adapter, which is typically the
* primary display GPU installed on the system. A negative device_id is invalid.
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id);
/**
* Creates a DirectML Execution Provider using the given DirectML device, and which executes work on the supplied D3D12
* command queue. The DirectML device and D3D12 command queue must have the same parent ID3D12Device, or an error will
* be returned. The D3D12 command queue must be of type DIRECT or COMPUTE (see D3D12_COMMAND_LIST_TYPE). If this
* function succeeds, the inference session maintains a strong reference on both the dml_device and the command_queue
* objects.
* See also: DMLCreateDevice
* See also: ID3D12Device::CreateCommandQueue
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML1, _In_ OrtSessionOptions* options,
_In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue);
/**
* CreateGPUAllocationFromD3DResource
* This api is used to create a DML EP input based on a user specified d3d12 resource.
*/
ORT_API2_STATUS(CreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* d3d_resource, _Out_ void** dml_resource);
/**
* FreeGPUAllocation
* This api is used free the DML EP input created by CreateGPUAllocationFromD3DResource.
*/
ORT_API2_STATUS(FreeGPUAllocation, _In_ void* dml_resource);
/**
* GetD3D12ResourceFromAllocation
* This api is used to get the D3D12 resource when a OrtValue has been allocated by the DML EP and accessed via GetMutableTensorData.
*/
ORT_API2_STATUS(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* provider, _In_ void* dml_allocation, _Out_ ID3D12Resource** d3d_resource);
};
#ifdef __cplusplus
}
#endif

View file

@ -1,9 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// The winml "provider factory" is not a true execution provider.
// It is placed here as an execution provider to facilitate the export of the WinMLAdapter API
// via the OrtGetWinMLAdapter method.
#include "onnxruntime_c_api.h"
struct OrtWinApi;
typedef struct OrtWinApi OrtWinApi;
struct WinmlAdapterApi;
typedef struct WinmlAdapterApi WinmlAdapterApi;
ORT_EXPORT const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(_In_ const OrtApi* ort_api) NO_EXCEPTION;
ORT_EXPORT const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(_In_ uint32_t ort_api_version) NO_EXCEPTION;

View file

@ -3061,7 +3061,6 @@ struct OrtApi {
*/
ORT_API2_STATUS(HasValue, _In_ const OrtValue* value, _Out_ int* out);
/// @}
/// \name OrtKernelContext
/// @{
/** \brief Used for custom operators, gets the GPU compute stream to use to launch the custom a GPU kernel
@ -3074,6 +3073,33 @@ struct OrtApi {
* Only use it for custom kernel launching.
*/
ORT_API2_STATUS(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out);
/// @}
/// \name GetTensorMemoryInfo
/// @{
/** \brief Returns a pointer to the ::OrtMemoryInfo of a Tensor
* \param[in] ort_value ::OrtValue containing tensor.
* \param[out] mem_info ::OrtMemoryInfo of the tensor. Do NOT free the returned pointer. It is valid for the lifetime of the ::OrtValue
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(GetTensorMemoryInfo, _In_ const OrtValue* value, _Out_ const OrtMemoryInfo** mem_info);
/// @}
/// \name GetExecutionProviderApi
/// @{
/** \brief Get a pointer to the requested version of the Execution Provider specific
* API extensions to the OrtApi
* \param[in] provider_name The name of the execution provider name. Currently only the following
* values are supported: "DML".
* \param[in] version Must be ::ORT_API_VERSION.
* \param[out] provider_api A void pointer containing a reference to the execution provider versioned api structure.
* For example, the provider_api pointer can be cast to the OrtDmlApi* when the provider_name is "DML".
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(GetExecutionProviderApi, _In_ const char* provider_name, _In_ uint32_t version, _Outptr_ const void** provider_api);
/// @}
/// \name SessionOptions

View file

@ -916,6 +916,8 @@ struct CustomOpApi {
template <typename T>
const T* GetTensorData(_Inout_ const OrtValue* value);
const OrtMemoryInfo* GetTensorMemoryInfo(_In_ const OrtValue* value);
std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info);
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input);
size_t KernelContext_GetInputCount(const OrtKernelContext* context);

View file

@ -10,7 +10,7 @@ class DFT final : public OpKernel {
bool is_onesided_ = true;
public:
explicit DFT(const OpKernelInfo& info) : OpKernel(info) {
is_onesided_ = info.GetAttrOrDefault<int64_t>("onesided", 0);
is_onesided_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("onesided", 0));
}
Status Compute(OpKernelContext* ctx) const override;
};
@ -26,7 +26,7 @@ class STFT final : public OpKernel {
bool is_onesided_ = true;
public:
explicit STFT(const OpKernelInfo& info) : OpKernel(info) {
is_onesided_ = info.GetAttrOrDefault<int64_t>("onesided", 1);
is_onesided_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("onesided", 1));
}
Status Compute(OpKernelContext* ctx) const override;
};

View file

@ -67,6 +67,10 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA
*out = new OrtMemoryInfo(
onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast<OrtDevice::DeviceId>(id1)),
id1, mem_type1);
} else if (strcmp(name1, onnxruntime::DML) == 0) {
*out = new OrtMemoryInfo(
onnxruntime::DML, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)),
id1, mem_type1);
} else {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported.");
}

View file

@ -43,7 +43,7 @@ namespace Dml
)
: onnxruntime::IAllocator(
OrtMemoryInfo(
"DML allocator",
"DML",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)
)

View file

@ -12,6 +12,7 @@ using Microsoft::WRL::ComPtr;
#include "core/providers/dml/dml_provider_factory.h"
#include "core/session/abi_session_options_impl.h"
#include "core/session/allocator_adapters.h"
#include "core/session/ort_apis.h"
#include "core/framework/error_code_helper.h"
#include "DmlExecutionProvider/src/ErrorHandling.h"
@ -141,6 +142,9 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(in
} // namespace onnxruntime
// [[deprecated]]
// This export should be deprecated.
// The OrtSessionOptionsAppendExecutionProvider_DML export on the OrtDmlApi should be used instead.
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id) {
API_IMPL_BEGIN
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_DML(device_id));
@ -148,6 +152,9 @@ API_IMPL_END
return nullptr;
}
// [[deprecated]]
// This export should be deprecated.
// The OrtSessionOptionsAppendExecutionProvider_DML1 export on the OrtDmlApi should be used instead.
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options,
_In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue) {
API_IMPL_BEGIN
@ -156,3 +163,55 @@ API_IMPL_BEGIN
API_IMPL_END
return nullptr;
}
ORT_API_STATUS_IMPL(CreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* d3d_resource, _Out_ void** dml_resource) {
API_IMPL_BEGIN
#ifdef USE_DML
*dml_resource = Dml::CreateGPUAllocationFromD3DResource(d3d_resource);
#else
*dml_resource = nullptr;
#endif // USE_DML
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(FreeGPUAllocation, _In_ void* ptr) {
API_IMPL_BEGIN
#ifdef USE_DML
Dml::FreeGPUAllocation(ptr);
#endif // USE_DML
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* ort_allocator, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) {
API_IMPL_BEGIN
#ifdef USE_DML
auto wrapping_allocator = static_cast<onnxruntime::OrtAllocatorImplWrappingIAllocator*>(ort_allocator);
auto allocator = wrapping_allocator->GetWrappedIAllocator();
if (!allocator) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available");
}
*d3d_resource = Dml::GetD3D12ResourceFromAllocation(allocator.get(), allocation);
(*d3d_resource)->AddRef();
#else
*d3d_resource = nullptr;
#endif // USE_DML
return nullptr;
API_IMPL_END
}
static constexpr OrtDmlApi ort_dml_api_10_to_x = {
&OrtSessionOptionsAppendExecutionProvider_DML,
&OrtSessionOptionsAppendExecutionProviderEx_DML,
&CreateGPUAllocationFromD3DResource,
&FreeGPUAllocation,
&GetD3D12ResourceFromAllocation
};
const OrtDmlApi* GetOrtDmlApi(_In_ uint32_t /*version*/) NO_EXCEPTION {
#ifdef USE_DML
return &ort_dml_api_10_to_x;
#endif // USE_DML
return nullptr;
}

View file

@ -31,6 +31,10 @@ const OrtMemoryInfo* OrtAllocatorImplWrappingIAllocator::Info() const {
return &i_allocator_->Info();
}
onnxruntime::AllocatorPtr OrtAllocatorImplWrappingIAllocator::GetWrappedIAllocator() {
return i_allocator_;
}
IAllocatorImplWrappingOrtAllocator::IAllocatorImplWrappingOrtAllocator(OrtAllocator* ort_allocator)
: IAllocator(*ort_allocator->Info(ort_allocator)), ort_allocator_(ort_allocator) {}

View file

@ -30,6 +30,8 @@ struct OrtAllocatorImplWrappingIAllocator final : public OrtAllocatorImpl {
ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtAllocatorImplWrappingIAllocator);
onnxruntime::AllocatorPtr GetWrappedIAllocator();
private:
onnxruntime::AllocatorPtr i_allocator_;
};

View file

@ -45,6 +45,11 @@ ProviderInfo_CUDA* TryGetProviderInfo_CUDA();
}
#endif
#ifdef USE_DML
#include "core/providers/dml/dml_provider_factory.h"
const OrtDmlApi* GetOrtDmlApi(_In_ uint32_t version) NO_EXCEPTION;
#endif
#ifdef ENABLE_EXTENSION_CUSTOM_OPS
#include "onnxruntime_extensions.h"
#endif
@ -1911,6 +1916,27 @@ ORT_API_STATUS_IMPL(OrtApis::ReleaseAvailableProviders, _In_ char** ptr,
return NULL;
}
ORT_API_STATUS_IMPL(OrtApis::GetExecutionProviderApi,
[[maybe_unused]] _In_ const char* provider_name,
[[maybe_unused]] _In_ uint32_t version,
_Outptr_ const void** provider_api) {
API_IMPL_BEGIN
*provider_api = nullptr;
#ifdef USE_DML
if (strcmp(provider_name, "DML") == 0) {
*provider_api = GetOrtDmlApi(version);
if (*provider_api == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified version is not supported for the DirectML provider.");
}
return NULL;
}
#endif
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified provider is not supported.");
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count,
_Outptr_ void** out) {
TENSOR_READWRITE_API_BEGIN
@ -2082,6 +2108,13 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArrayWithPrepackedWeightsContainer
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::GetTensorMemoryInfo, _In_ const OrtValue* value, _Outptr_ const OrtMemoryInfo** memory_info) {
TENSOR_READ_API_BEGIN
*memory_info = &tensor.Location();
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
API_IMPL_BEGIN
options->value.custom_create_thread_fn = ort_custom_create_thread_fn;
@ -2379,6 +2412,8 @@ static constexpr OrtApi ort_api_1_to_10 = {
// Version 10 - In development, feel free to add/remove/rearrange here
&OrtApis::HasValue,
&OrtApis::KernelContext_GetGPUComputeStream,
&OrtApis::GetTensorMemoryInfo,
&OrtApis::GetExecutionProviderApi,
&OrtApis::SessionOptionsSetCustomCreateThreadFn,
&OrtApis::SessionOptionsSetCustomThreadCreationOptions,
&OrtApis::SessionOptionsSetCustomJoinThreadFn,

View file

@ -317,6 +317,8 @@ ORT_API_STATUS_IMPL(GetSparseTensorValues, _In_ const OrtValue* ort_value, _Outp
ORT_API_STATUS_IMPL(GetSparseTensorIndicesTypeShape, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Outptr_ OrtTensorTypeAndShapeInfo** out);
ORT_API_STATUS_IMPL(GetSparseTensorIndices, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Out_ size_t* num_indices, _Outptr_ const void** indices);
ORT_API_STATUS_IMPL(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out);
ORT_API_STATUS_IMPL(GetTensorMemoryInfo, _In_ const OrtValue* value, _Outptr_ const OrtMemoryInfo** memory_info);
ORT_API_STATUS_IMPL(GetExecutionProviderApi, _In_ const char* provider_name, _In_ uint32_t version, _Outptr_ const void** provider_api);
ORT_API_STATUS_IMPL(SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn);
ORT_API_STATUS_IMPL(SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options);
ORT_API_STATUS_IMPL(SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn);

View file

@ -63,16 +63,12 @@ ORT_API_STATUS(SessionGetNamedDimensionsOverrides, _In_ OrtSession* session, _Ou
ORT_API_STATUS(DmlExecutionProviderSetDefaultRoundingMode, _In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled);
ORT_API_STATUS(DmlExecutionProviderFlushContext, _In_ OrtExecutionProvider* dml_provider);
ORT_API_STATUS(DmlExecutionProviderReleaseCompletedReferences, _In_ OrtExecutionProvider* dml_provider);
ORT_API_STATUS(DmlCreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* pResource, _Out_ void** dml_resource);
ORT_API_STATUS(DmlGetD3D12ResourceFromAllocation, _In_ OrtExecutionProvider* provider, _In_ void* allocation, _Out_ ID3D12Resource** resource);
ORT_API_STATUS(DmlFreeGPUAllocation, _In_ void* ptr);
// note: this returns a weak ref
ORT_API_STATUS(GetProviderMemoryInfo, _In_ OrtExecutionProvider* provider, OrtMemoryInfo** memory_info);
ORT_API_STATUS(GetProviderAllocator, _In_ OrtExecutionProvider* provider, OrtAllocator** allocator);
ORT_API_STATUS(FreeProviderAllocator, _In_ OrtAllocator* allocator);
ORT_API_STATUS(GetValueMemoryInfo, const OrtValue* value, OrtMemoryInfo** memory_info);
// ExecutionProvider Methods
ORT_API_STATUS(ExecutionProviderSync, _In_ OrtExecutionProvider* provider);

View file

@ -64,15 +64,11 @@ static constexpr WinmlAdapterApi winml_adapter_api_1 = {
&winmla::DmlExecutionProviderSetDefaultRoundingMode,
&winmla::DmlExecutionProviderFlushContext,
&winmla::DmlExecutionProviderReleaseCompletedReferences,
&winmla::DmlCreateGPUAllocationFromD3DResource,
&winmla::DmlFreeGPUAllocation,
&winmla::DmlGetD3D12ResourceFromAllocation,
&winmla::DmlCopyTensor,
&winmla::GetProviderMemoryInfo,
&winmla::GetProviderAllocator,
&winmla::FreeProviderAllocator,
&winmla::GetValueMemoryInfo,
&winmla::ExecutionProviderSync,
@ -100,10 +96,10 @@ static constexpr WinmlAdapterApi winml_adapter_api_1 = {
&winmla::ReleaseModel
};
const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(_In_ const OrtApi* ort_api) NO_EXCEPTION {
if (OrtApis::GetApi(2) == ort_api) {
const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(_In_ uint32_t ort_version) NO_EXCEPTION {
if (ort_version >= 2) {
return &winml_adapter_api_1;
}
return nullptr;
}
}

View file

@ -320,9 +320,9 @@ struct WinmlAdapterApi {
/**
* DmlExecutionProviderSetDefaultRoundingMode
* This api is used to configure the DML EP to turn on/off rounding.
* This api is used to configure the DML EP to turn on/off rounding.
*
* WinML uses this to disable rounding during session initialization and then enables it again post initialization.
* WinML uses this to disable rounding during session initialization and then enables it again post initialization.
*/
OrtStatus*(ORT_API_CALL* DmlExecutionProviderSetDefaultRoundingMode)(_In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled)NO_EXCEPTION;
@ -342,30 +342,6 @@ struct WinmlAdapterApi {
*/
OrtStatus*(ORT_API_CALL* DmlExecutionProviderReleaseCompletedReferences)(_In_ OrtExecutionProvider* dml_provider)NO_EXCEPTION;
/**
* DmlCreateGPUAllocationFromD3DResource
* This api is used to create a DML EP input based on a user specified d3d12 resource.
*
* WinML uses this as part of its Tensor apis to allow callers to specify their own D3D12 resources as inputs/outputs.
*/
OrtStatus*(ORT_API_CALL* DmlCreateGPUAllocationFromD3DResource)(_In_ ID3D12Resource* pResource, _Out_ void** dml_resource)NO_EXCEPTION;
/**
* DmlFreeGPUAllocation
* This api is used free the DML EP input created by DmlCreateGPUAllocationFromD3DResource.
*
* WinML uses this as part of its Tensor apis to allow callers to specify their own D3D12 resources as inputs/outputs.
*/
OrtStatus*(ORT_API_CALL* DmlFreeGPUAllocation)(_In_ void* ptr)NO_EXCEPTION;
/**
* DmlGetD3D12ResourceFromAllocation
* This api is used to get the D3D12 resource when a OrtValue has been allocated by the DML EP and accessed via GetMutableTensorData.
*
* WinML uses this in the image feature path to get the d3d resource and perform and tensorization on inputs directly into the allocated d3d12 resource.
*/
OrtStatus*(ORT_API_CALL* DmlGetD3D12ResourceFromAllocation)(_In_ OrtExecutionProvider* provider, _In_ void* allocation, _Out_ ID3D12Resource** resource)NO_EXCEPTION;
/**
* DmlCopyTensor
* This api is used copy a tensor allocated by the DML EP Allocator to the CPU.
@ -399,14 +375,6 @@ struct WinmlAdapterApi {
*/
OrtStatus*(ORT_API_CALL* FreeProviderAllocator)(_In_ OrtAllocator* allocator)NO_EXCEPTION;
/**
* GetValueMemoryInfo
* This api gets the memory info of an OrtValue.
*
* WinML uses this to determine if an OrtValue is allocated on the Cpu or elsewhere.
*/
OrtStatus*(ORT_API_CALL* GetValueMemoryInfo)(const OrtValue* value, OrtMemoryInfo** memory_info)NO_EXCEPTION;
/**
* ExecutionProviderSync
* This api syncs the EP.
@ -497,4 +465,4 @@ struct WinmlAdapterApi {
_In_ const char* const join_node_prefix)NO_EXCEPTION;
ORT_CLASS_RELEASE(Model);
};
};

View file

@ -126,42 +126,6 @@ ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderReleaseCompletedReferences, _In_
API_IMPL_END
}
ORT_API_STATUS_IMPL(winmla::DmlCreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* pResource, _Out_ void** dml_resource) {
API_IMPL_BEGIN
#ifdef USE_DML
*dml_resource = Dml::CreateGPUAllocationFromD3DResource(pResource);
#else
*dml_resource = nullptr;
#endif // USE_DML USE_DML
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(winmla::DmlGetD3D12ResourceFromAllocation, _In_ OrtExecutionProvider* dml_provider, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) {
API_IMPL_BEGIN
#ifdef USE_DML
auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider);
*d3d_resource =
Dml::GetD3D12ResourceFromAllocation(
dml_provider_internal->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault).get(),
allocation);
(*d3d_resource)->AddRef();
#else
*d3d_resource = nullptr;
#endif // USE_DML USE_DML
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(winmla::DmlFreeGPUAllocation, _In_ void* ptr) {
API_IMPL_BEGIN
#ifdef USE_DML
Dml::FreeGPUAllocation(ptr);
#endif // USE_DML USE_DML
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(winmla::DmlCopyTensor, _In_ OrtExecutionProvider* dml_provider, _In_ OrtValue* src, _In_ OrtValue* dst) {
API_IMPL_BEGIN
#ifdef USE_DML

View file

@ -76,16 +76,4 @@ ORT_API_STATUS_IMPL(winmla::FreeProviderAllocator, _In_ OrtAllocator* allocator)
delete static_cast<OrtAllocatorWrapper*>(allocator);
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(winmla::GetValueMemoryInfo, const OrtValue* value, OrtMemoryInfo** memory_info) {
API_IMPL_BEGIN
const auto& tensor = value->Get<onnxruntime::Tensor>();
auto info = tensor.Location();
*memory_info = new OrtMemoryInfo(info.name, info.alloc_type, info.device, info.id, info.mem_type);
if (*memory_info == nullptr) {
return OrtApis::CreateStatus(ORT_FAIL, "Out of memory");
}
return nullptr;
API_IMPL_END
}
}

View file

@ -11,6 +11,8 @@
#include "OnnxruntimeSessionBuilder.h"
#include "OnnxruntimeErrors.h"
#include "core/providers/dml/dml_provider_factory.h"
using namespace _winml;
static ONNXTensorElementDataType
@ -89,19 +91,17 @@ HRESULT OnnxruntimeValue::IsEmpty(bool* out) {
HRESULT OnnxruntimeValue::IsCpu(bool* out) {
auto ort_api = engine_->GetEngineFactory()->UseOrtApi();
auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi();
OrtMemoryInfo* ort_memory_info;
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetValueMemoryInfo(value_.get(), &ort_memory_info),
const OrtMemoryInfo* ort_memory_info;
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMemoryInfo(value_.get(), &ort_memory_info),
ort_api);
auto memory_info = UniqueOrtMemoryInfo(ort_memory_info, ort_api->ReleaseMemoryInfo);
const char* name;
RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetName(memory_info.get(), &name),
RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetName(ort_memory_info, &name),
ort_api);
OrtMemType type;
RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetMemType(memory_info.get(), &type),
RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetMemType(ort_memory_info, &type),
ort_api);
*out = !strcmp(name, "Cpu") ||
@ -167,20 +167,28 @@ static auto GetStrings(const OrtApi* ort_api, const OrtValue* ort_value,
HRESULT OnnxruntimeValue::GetResource(_winml::Resource& out) {
auto ort_api = engine_->GetEngineFactory()->UseOrtApi();
auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi();
void* mutable_data = nullptr;
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(value_.get(), &mutable_data),
ort_api);
OrtExecutionProvider* ort_provider;
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(engine_->UseOrtSession(), 0, &ort_provider),
const OrtMemoryInfo* ort_memory_info;
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMemoryInfo(value_.get(), &ort_memory_info),
ort_api);
bool is_cpu = false;
if (SUCCEEDED(IsCpu(&is_cpu)) && !is_cpu) {
const OrtDmlApi* ort_dml_api;
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&ort_dml_api)),
ort_api);
OrtAllocator* ort_allocator;
RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateAllocator(engine_->UseOrtSession(), ort_memory_info, &ort_allocator),
ort_api);
auto allocator = UniqueOrtAllocator(ort_allocator, ort_api->ReleaseAllocator);
winrt::com_ptr<ID3D12Resource> resource;
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlGetD3D12ResourceFromAllocation(ort_provider, mutable_data,
RETURN_HR_IF_NOT_OK_MSG(ort_dml_api->GetD3D12ResourceFromAllocation(allocator.get(), mutable_data,
resource.put()),
ort_api);
out = _winml::Resource(resource.get(), [](void*) { /*do nothing, as this pointer is actually a com pointer! */ });
@ -578,7 +586,11 @@ HRESULT OnnxruntimeEngine::CreateTensorValue(const int64_t* shape, size_t count,
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetProviderAllocator(ort_provider, &ort_allocator),
engine_factory_->UseOrtApi());
auto unique_allocator = UniqueOrtAllocator(ort_allocator, winml_adapter_api->FreeProviderAllocator); // the release here should probably not return anything
auto unique_allocator = UniqueOrtAllocator(
ort_allocator,
[](OrtAllocator* allocator) {
GetVersionedWinmlAdapterApi()->FreeProviderAllocator(allocator);
}); // the release here should probably not return anything
OrtValue* ort_value;
RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorAsOrtValue(unique_allocator.get(), shape, count, ONNXTensorElementDataTypeFromTensorKind(kind), &ort_value),
@ -612,30 +624,35 @@ class DmlAllocatorWrapper : public Microsoft::WRL::RuntimeClass<
*/
HRESULT OnnxruntimeEngine::CreateTensorValueFromExternalD3DResource(ID3D12Resource* d3d_resource, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) {
auto ort_api = engine_factory_->UseOrtApi();
auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
const OrtDmlApi* ort_dml_api;
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&ort_dml_api)),
ort_api);
OrtExecutionProvider* ort_provider;
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider),
engine_factory_->UseOrtApi());
OrtMemoryInfo* ort_memory_info;
RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateMemoryInfo("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info),
ort_api);
OrtMemoryInfo* dml_memory = nullptr;
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(ort_provider, &dml_memory),
engine_factory_->UseOrtApi());
OrtAllocator* ort_allocator;
RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateAllocator(session_.get(), ort_memory_info, &ort_allocator),
ort_api);
auto allocator = UniqueOrtAllocator(ort_allocator, ort_api->ReleaseAllocator);
void* dml_allocator_resource;
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(d3d_resource, &dml_allocator_resource),
RETURN_HR_IF_NOT_OK_MSG(ort_dml_api->CreateGPUAllocationFromD3DResource(d3d_resource, &dml_allocator_resource),
engine_factory_->UseOrtApi());
auto unique_dml_allocator_resource =
DmlAllocatorResource(dml_allocator_resource,
[](void* ptr) {
GetVersionedWinmlAdapterApi()->DmlFreeGPUAllocation(ptr);
const OrtDmlApi* ort_dml_api;
GetVersionedOrtApi()->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&ort_dml_api));
ort_dml_api->FreeGPUAllocation(ptr);
});
// create the OrtValue as a tensor letting ort know that we own the data buffer
OrtValue* ort_value;
RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue(
dml_memory,
ort_memory_info,
unique_dml_allocator_resource.get(),
static_cast<size_t>(d3d_resource->GetDesc().Width),
shape,

View file

@ -58,9 +58,7 @@ const OrtApi* _winml::GetVersionedOrtApi() {
}
const auto ort_api_base = ort_get_api_base_fn();
static const uint32_t ort_version = 2;
return ort_api_base->GetApi(ort_version);
return ort_api_base->GetApi(ORT_API_VERSION);
}
static const WinmlAdapterApi* GetVersionedWinmlAdapterApi(const OrtApi* ort_api) {
@ -73,7 +71,7 @@ static const WinmlAdapterApi* GetVersionedWinmlAdapterApi(const OrtApi* ort_api)
FAIL_FAST_HR(HRESULT_FROM_WIN32(GetLastError()));
}
return ort_get_winml_adapter_fn(ort_api);
return ort_get_winml_adapter_fn(ORT_API_VERSION);
}
const WinmlAdapterApi* _winml::GetVersionedWinmlAdapterApi() {

View file

@ -8,7 +8,7 @@
#include "adapter/winml_adapter_c_api.h"
using UniqueOrtModel = std::unique_ptr<OrtModel, decltype(WinmlAdapterApi::ReleaseModel)>;
using UniqueOrtAllocator = std::unique_ptr<OrtAllocator, decltype(WinmlAdapterApi::FreeProviderAllocator)>;
using UniqueOrtAllocator = std::unique_ptr<OrtAllocator, decltype(OrtApi::ReleaseAllocator)>;
using UniqueOrtSessionOptions = std::unique_ptr<OrtSessionOptions, decltype(OrtApi::ReleaseSessionOptions)>;
using UniqueOrtSession = std::unique_ptr<OrtSession, decltype(OrtApi::ReleaseSession)>;
using UniqueOrtValue = std::unique_ptr<OrtValue, decltype(OrtApi::ReleaseValue)>;
@ -17,3 +17,4 @@ using UniqueOrtTypeInfo = std::unique_ptr<OrtTypeInfo,
using UniqueOrtTensorTypeAndShapeInfo = std::unique_ptr<OrtTensorTypeAndShapeInfo, decltype(OrtApi::ReleaseTensorTypeAndShapeInfo)>;
using UniqueOrtRunOptions = std::unique_ptr<OrtRunOptions, decltype(OrtApi::ReleaseRunOptions)>;
using UniqueOrtEnv = std::unique_ptr<OrtEnv, decltype(OrtApi::ReleaseEnv)>;

View file

@ -10,6 +10,7 @@
#include "winml_adapter_c_api.h"
#include "core/framework/execution_provider.h"
#include "core/providers/winml/winml_provider_factory.h"
#include "core/providers/dml/dml_provider_factory.h"
#include "OnnxruntimeEngine.h"
#include "OnnxruntimeErrors.h"
#include "OnnxruntimeModel.h"
@ -17,15 +18,17 @@
namespace {
const WinmlAdapterApi* winml_adapter_api;
const OrtDmlApi* ort_dml_api;
const OrtApi* ort_api;
OrtEnv* ort_env;
void AdapterDmlEpTestSetup() {
GPUTEST;
winrt::init_apartment();
ort_api = OrtGetApiBase()->GetApi(2);
winml_adapter_api = OrtGetWinMLAdapter(ort_api);
ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env);
ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
THROW_IF_NOT_OK_MSG(ort_api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&ort_dml_api)), ort_api);
winml_adapter_api = OrtGetWinMLAdapter(ORT_API_VERSION);
THROW_IF_NOT_OK_MSG(ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env), ort_api);
#ifdef BUILD_INBOX
winrt_activation_handler = WINRT_RoGetActivationFactory;
#endif
@ -145,8 +148,8 @@ void DmlCreateAndFreeGPUAllocationFromD3DResource() {
auto d3d12_resource = CreateD3D12Resource(*device);
void* dml_allocator_resource;
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(d3d12_resource.get(), &dml_allocator_resource), ort_api);
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlFreeGPUAllocation(dml_allocator_resource), ort_api);
THROW_IF_NOT_OK_MSG(ort_dml_api->CreateGPUAllocationFromD3DResource(d3d12_resource.get(), &dml_allocator_resource), ort_api);
THROW_IF_NOT_OK_MSG(ort_dml_api->FreeGPUAllocation(dml_allocator_resource), ort_api);
}
void DmlGetD3D12ResourceFromAllocation() {
@ -156,68 +159,42 @@ void DmlGetD3D12ResourceFromAllocation() {
auto d3d12_resource = CreateD3D12Resource(*device);
void* gpu_allocation;
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(d3d12_resource.get(), &gpu_allocation), ort_api);
THROW_IF_NOT_OK_MSG(ort_dml_api->CreateGPUAllocationFromD3DResource(d3d12_resource.get(), &gpu_allocation), ort_api);
auto session = CreateDmlSession();
OrtExecutionProvider* ort_provider;
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api);
OrtMemoryInfo* ort_memory_info;
THROW_IF_NOT_OK_MSG(ort_api->CreateMemoryInfo("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info), ort_api);
OrtAllocator* ort_allocator;
THROW_IF_NOT_OK_MSG(ort_api->CreateAllocator(session.get(), ort_memory_info, &ort_allocator), ort_api);
auto allocator = UniqueOrtAllocator(ort_allocator, ort_api->ReleaseAllocator);
winrt::com_ptr<ID3D12Resource> d3d12_resource_from_allocation;
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlGetD3D12ResourceFromAllocation(ort_provider, gpu_allocation, d3d12_resource_from_allocation.put()), ort_api);
THROW_IF_NOT_OK_MSG(ort_dml_api->GetD3D12ResourceFromAllocation(allocator.get(), gpu_allocation, d3d12_resource_from_allocation.put()), ort_api);
// Ensure resource is the same
WINML_EXPECT_EQUAL(d3d12_resource, d3d12_resource_from_allocation);
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlFreeGPUAllocation(gpu_allocation), ort_api);
THROW_IF_NOT_OK_MSG(ort_dml_api->FreeGPUAllocation(gpu_allocation), ort_api);
}
UniqueOrtValue CreateTensorFromMemoryInfo(OrtMemoryInfo* memory_info) {
UniqueOrtValue CreateTensorFromMemoryInfo(const OrtMemoryInfo* memory_info) {
OrtValue* tensor;
THROW_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue(memory_info, tensor_values.data(), tensor_size * sizeof(float), dimensions.data(), dimensions.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &tensor), ort_api);
return UniqueOrtValue(tensor, ort_api->ReleaseValue);
}
void GetProviderMemoryInfo() {
void GetTensorMemoryInfo() {
GPUTEST;
auto session = CreateDmlSession();
OrtExecutionProvider* ort_provider;
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api);
OrtMemoryInfo *memory_info;
THROW_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(ort_provider, &memory_info), ort_api);
auto unique_memory_info = UniqueOrtMemoryInfo(memory_info, ort_api->ReleaseMemoryInfo);
// Ensure tensor can be created with the provided OrtMemoryInfo
CreateTensorFromMemoryInfo(unique_memory_info.get());
}
void GetAndFreeProviderAllocator() {
GPUTEST;
auto session = CreateDmlSession();
OrtExecutionProvider* ort_provider;
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api);
OrtAllocator *allocator;
THROW_IF_NOT_OK_MSG(winml_adapter_api->GetProviderAllocator(ort_provider, &allocator), ort_api);
auto unique_allocator = UniqueOrtAllocator(allocator, winml_adapter_api->FreeProviderAllocator);
OrtMemoryInfo* ort_memory_info;
THROW_IF_NOT_OK_MSG(ort_api->CreateMemoryInfo("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info), ort_api);
auto tensor = CreateTensorFromMemoryInfo(ort_memory_info);
// Ensure allocation works
void *data = nullptr;
THROW_IF_NOT_OK_MSG(ort_api->AllocatorAlloc(unique_allocator.get(), 1024, &data), ort_api);
WINML_EXPECT_NOT_EQUAL(nullptr, data);
THROW_IF_NOT_OK_MSG(ort_api->AllocatorFree(unique_allocator.get(), data), ort_api);
}
void GetValueMemoryInfo() {
GPUTEST;
auto session = CreateDmlSession();
OrtExecutionProvider* ort_provider;
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api);
OrtMemoryInfo *memory_info;
THROW_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(ort_provider, &memory_info), ort_api);
auto unique_memory_info = UniqueOrtMemoryInfo(memory_info, ort_api->ReleaseMemoryInfo);
auto tensor = CreateTensorFromMemoryInfo(unique_memory_info.get());
OrtMemoryInfo *value_memory_info;
THROW_IF_NOT_OK_MSG(winml_adapter_api->GetValueMemoryInfo(tensor.get(), &value_memory_info), ort_api);
auto unique_value_memory_info = UniqueOrtMemoryInfo(value_memory_info, ort_api->ReleaseMemoryInfo);
CreateTensorFromMemoryInfo(unique_value_memory_info.get());
const OrtMemoryInfo* value_memory_info;
THROW_IF_NOT_OK_MSG(ort_api->GetTensorMemoryInfo(tensor.get(), &value_memory_info), ort_api);
CreateTensorFromMemoryInfo(value_memory_info);
}
void ExecutionProviderSync() {
@ -255,18 +232,17 @@ void DmlCopyTensor() {
WINML_EXPECT_NOT_EQUAL(nullptr, winml_adapter_api->DmlCopyTensor(dml_provider, cpu_tensor.get(), dst_cpu_tensor.get()));
// GPU to CPU
OrtMemoryInfo* dml_memory;
THROW_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(dml_provider, &dml_memory), ort_api);
auto unique_dml_memory = UniqueOrtMemoryInfo(dml_memory, ort_api->ReleaseMemoryInfo);
OrtMemoryInfo* ort_memory_info;
THROW_IF_NOT_OK_MSG(ort_api->CreateMemoryInfo("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info), ort_api);
auto resource = CreateD3D12Resource(*device);
void* dml_allocator_resource;
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(resource.get(), &dml_allocator_resource), ort_api);
THROW_IF_NOT_OK_MSG(ort_dml_api->CreateGPUAllocationFromD3DResource(resource.get(), &dml_allocator_resource), ort_api);
std::array<int64_t, 3> shape = {720, 720, 3};
OrtValue* gpu_value;
THROW_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue(
dml_memory,
ort_memory_info,
dml_allocator_resource,
static_cast<size_t>(resource->GetDesc().Width),
shape.data(),
@ -277,7 +253,7 @@ void DmlCopyTensor() {
dst_cpu_tensor = CreateTensorFromMemoryInfo(cpu_memory_info);
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlCopyTensor(dml_provider, gpu_value, dst_cpu_tensor.get()), ort_api);
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlFreeGPUAllocation(dml_allocator_resource), ort_api);
THROW_IF_NOT_OK_MSG(ort_dml_api->FreeGPUAllocation(dml_allocator_resource), ort_api);
}
void CreateCustomRegistry() {
@ -290,18 +266,17 @@ void CreateCustomRegistry() {
void ValueGetDeviceId() {
GPUTEST;
auto session = CreateDmlSession();
OrtExecutionProvider* ort_provider;
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api);
OrtMemoryInfo *memory_info;
THROW_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(ort_provider, &memory_info), ort_api);
auto gpu_tensor = CreateTensorFromMemoryInfo(memory_info);
OrtMemoryInfo* ort_memory_info;
THROW_IF_NOT_OK_MSG(ort_api->CreateMemoryInfo("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info), ort_api);
auto gpu_tensor = CreateTensorFromMemoryInfo(ort_memory_info);
int16_t device_id;
THROW_IF_NOT_OK_MSG(winml_adapter_api->ValueGetDeviceId(gpu_tensor.get(), &device_id), ort_api);
OrtMemoryInfo* cpu_memory_info;
THROW_IF_NOT_OK_MSG(ort_api->CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &cpu_memory_info), ort_api);
auto unique_cpu_memory_info = UniqueOrtMemoryInfo(memory_info, ort_api->ReleaseMemoryInfo);
auto unique_cpu_memory_info = UniqueOrtMemoryInfo(cpu_memory_info, ort_api->ReleaseMemoryInfo);
auto cpu_tensor = CreateTensorFromMemoryInfo(unique_cpu_memory_info.get());
THROW_IF_NOT_OK_MSG(winml_adapter_api->ValueGetDeviceId(cpu_tensor.get(), &device_id), ort_api);
WINML_EXPECT_EQUAL(0, device_id);
@ -329,9 +304,7 @@ const AdapterDmlEpTestApi& getapi() {
DmlExecutionProviderReleaseCompletedReferences,
DmlCreateAndFreeGPUAllocationFromD3DResource,
DmlGetD3D12ResourceFromAllocation,
GetProviderMemoryInfo,
GetAndFreeProviderAllocator,
GetValueMemoryInfo,
GetTensorMemoryInfo,
ExecutionProviderSync,
DmlCopyTensor,
CreateCustomRegistry,

View file

@ -11,9 +11,7 @@ struct AdapterDmlEpTestApi
VoidTest DmlExecutionProviderReleaseCompletedReferences;
VoidTest DmlCreateGPUAllocationFromD3DResource;
VoidTest DmlCreateAndFreeGPUAllocationFromD3DResource;
VoidTest GetProviderMemoryInfo;
VoidTest GetAndFreeProviderAllocator;
VoidTest GetValueMemoryInfo;
VoidTest GetTensorMemoryInfo;
VoidTest ExecutionProviderSync;
VoidTest DmlCopyTensor;
VoidTest CreateCustomRegistry;
@ -31,9 +29,7 @@ WINML_TEST(AdapterDmlEpTest, DmlExecutionProviderFlushContext)
WINML_TEST(AdapterDmlEpTest, DmlExecutionProviderReleaseCompletedReferences)
WINML_TEST(AdapterDmlEpTest, DmlCreateGPUAllocationFromD3DResource)
WINML_TEST(AdapterDmlEpTest, DmlCreateAndFreeGPUAllocationFromD3DResource)
WINML_TEST(AdapterDmlEpTest, GetProviderMemoryInfo)
WINML_TEST(AdapterDmlEpTest, GetAndFreeProviderAllocator)
WINML_TEST(AdapterDmlEpTest, GetValueMemoryInfo)
WINML_TEST(AdapterDmlEpTest, GetTensorMemoryInfo)
WINML_TEST(AdapterDmlEpTest, ExecutionProviderSync)
WINML_TEST(AdapterDmlEpTest, DmlCopyTensor)
WINML_TEST(AdapterDmlEpTest, CreateCustomRegistry)

View file

@ -17,6 +17,7 @@
#include "core/common/logging/logging.h"
#include "core/session/abi_session_options_impl.h"
#include "core/session/ort_env.h"
#include "core/providers/dml/dml_provider_factory.h"
using namespace _winml;
using namespace winrt::Windows::Foundation::Collections;
@ -27,8 +28,8 @@ using namespace winrt::Windows::Storage::Streams;
namespace {
winrt::com_ptr<_winml::OnnxruntimeEngineFactory> engine_factory;
const OrtApi *ort_api;
const WinmlAdapterApi *winml_adapter_api;
const OrtApi* ort_api;
const WinmlAdapterApi* winml_adapter_api;
OrtEnv* ort_env;
void AdapterSessionTestSetup() {

View file

@ -11,8 +11,8 @@ static void AdapterTestSetup() {
#ifdef BUILD_INBOX
winrt_activation_handler = WINRT_RoGetActivationFactory;
#endif
ort_api = OrtGetApiBase()->GetApi(2);
winml_adapter_api = OrtGetWinMLAdapter(ort_api);
ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
winml_adapter_api = OrtGetWinMLAdapter(ORT_API_VERSION);
// for model tests
std::wstring module_path = FileHelpers::GetModulePath();