mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
[DML EP] Complete python IO binding implementation (#17344)
@fdwr This is part 2 of the pybind work that was started earlier. It adds the following features to the Python IO binding implementation: - Use a bucketized allocator to reduce the number of resource allocations - Implement the following functions: `ortvalue_from_numpy`, `update_inplace`, `ortvalue_from_shape_and_type` and `numpy` - Modify the `onnxruntime_test_python_iobinding` tests to also run on DML --------- Co-authored-by: Jeff Bloomfield <jeffbloo@microsoft.com>
This commit is contained in:
parent
c0a4fe777f
commit
54a092c427
10 changed files with 402 additions and 243 deletions
|
|
@ -212,15 +212,6 @@ namespace Dml
|
|||
ORT_THROW_HR(E_INVALIDARG);
|
||||
}
|
||||
const auto* allocInfo = static_cast<const AllocationInfo*>(opaqueHandle);
|
||||
|
||||
auto owner = allocInfo->GetOwner();
|
||||
//The owner can be null if the resource was wrapped via CreateGPUAllocationFromD3DResource
|
||||
if (owner != nullptr && owner != this)
|
||||
{
|
||||
// This allocation doesn't belong to this allocator!
|
||||
ORT_THROW_HR(E_INVALIDARG);
|
||||
}
|
||||
|
||||
return allocInfo;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -83,16 +83,16 @@ namespace Dml
|
|||
std::vector<Bucket> m_pool;
|
||||
size_t m_currentAllocationId = 0;
|
||||
uint64_t m_currentResourceId = 0;
|
||||
|
||||
// Unless specifically requested, allocation sizes are not rounded to enable pooling
|
||||
// until SetDefaultRoundingMode is called. This should be done at completion of session
|
||||
|
||||
// Unless specifically requested, allocation sizes are not rounded to enable pooling
|
||||
// until SetDefaultRoundingMode is called. This should be done at completion of session
|
||||
// initialization.
|
||||
AllocatorRoundingMode m_defaultRoundingMode = AllocatorRoundingMode::Disabled;
|
||||
|
||||
std::shared_ptr<ExecutionContext> m_context;
|
||||
std::unique_ptr<DmlSubAllocator> m_subAllocator;
|
||||
|
||||
#if _DEBUG
|
||||
#ifndef NDEBUG
|
||||
// Useful for debugging; keeps track of all allocations that haven't been freed yet
|
||||
std::map<size_t, AllocationInfo*> m_outstandingAllocationsById;
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -128,6 +128,31 @@ Microsoft::WRL::ComPtr<ID3D12Device> DMLProviderFactoryCreator::CreateD3D12Devic
|
|||
return d3d12_device;
|
||||
}
|
||||
|
||||
// Creates a DirectML device on top of the given D3D12 device, requesting DML_FEATURE_LEVEL_5_0.
// In debug builds (outside of _GAMING_XBOX) the DML debug flag is added when the D3D12 device exposes
// ID3D12DebugDevice. A failure of DMLCreateDevice1 is reported through ORT_THROW_IF_FAILED.
// d3d12_device is dereferenced without a null check, so callers must pass a valid device.
Microsoft::WRL::ComPtr<IDMLDevice> DMLProviderFactoryCreator::CreateDMLDevice(ID3D12Device* d3d12_device)
|
||||
{
|
||||
DML_CREATE_DEVICE_FLAGS flags = DML_CREATE_DEVICE_FLAG_NONE;
|
||||
|
||||
// In debug builds, enable the DML debug layer if the D3D12 debug layer is also enabled
|
||||
#if _DEBUG && !_GAMING_XBOX
|
||||
Microsoft::WRL::ComPtr<ID3D12DebugDevice> debug_device;
|
||||
(void)d3d12_device->QueryInterface(IID_PPV_ARGS(&debug_device)); // ignore failure
|
||||
const bool is_d3d12_debug_layer_enabled = (debug_device != nullptr);
|
||||
|
||||
if (is_d3d12_debug_layer_enabled) {
|
||||
flags |= DML_CREATE_DEVICE_FLAG_DEBUG;
|
||||
}
|
||||
#endif
|
||||
|
||||
Microsoft::WRL::ComPtr<IDMLDevice> dml_device;
|
||||
ORT_THROW_IF_FAILED(DMLCreateDevice1(
|
||||
d3d12_device,
|
||||
flags,
|
||||
DML_FEATURE_LEVEL_5_0,
|
||||
IID_PPV_ARGS(&dml_device)));
|
||||
|
||||
return dml_device;
|
||||
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) {
|
||||
ComPtr<ID3D12Device> d3d12_device = CreateD3D12Device(device_id, skip_software_device_check);
|
||||
|
||||
|
|
@ -138,25 +163,7 @@ std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int
|
|||
ComPtr<ID3D12CommandQueue> cmd_queue;
|
||||
ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf())));
|
||||
|
||||
DML_CREATE_DEVICE_FLAGS flags = DML_CREATE_DEVICE_FLAG_NONE;
|
||||
|
||||
// In debug builds, enable the DML debug layer if the D3D12 debug layer is also enabled
|
||||
#if _DEBUG && !_GAMING_XBOX
|
||||
ComPtr<ID3D12DebugDevice> debug_device;
|
||||
(void)d3d12_device->QueryInterface(IID_PPV_ARGS(&debug_device)); // ignore failure
|
||||
const bool is_d3d12_debug_layer_enabled = (debug_device != nullptr);
|
||||
|
||||
if (is_d3d12_debug_layer_enabled) {
|
||||
flags |= DML_CREATE_DEVICE_FLAG_DEBUG;
|
||||
}
|
||||
#endif
|
||||
|
||||
ComPtr<IDMLDevice> dml_device;
|
||||
ORT_THROW_IF_FAILED(DMLCreateDevice1(d3d12_device.Get(),
|
||||
flags,
|
||||
DML_FEATURE_LEVEL_5_0,
|
||||
IID_PPV_ARGS(&dml_device)));
|
||||
|
||||
auto dml_device = CreateDMLDevice(d3d12_device.Get());
|
||||
return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,5 +16,6 @@ struct DMLProviderFactoryCreator {
|
|||
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id);
|
||||
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id, bool skip_software_device_check);
|
||||
static Microsoft::WRL::ComPtr<ID3D12Device> CreateD3D12Device(int device_id, bool skip_software_device_check);
|
||||
static Microsoft::WRL::ComPtr<IDMLDevice> CreateDMLDevice(ID3D12Device* d3d12_device);
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -26,7 +26,18 @@
|
|||
#include "core/framework/provider_options_utils.h"
|
||||
|
||||
#ifdef USE_DML
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h"
|
||||
using Microsoft::WRL::ComPtr;
|
||||
|
||||
#include <wil/wrl.h>
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/External/D3DX12/d3dx12.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/ErrorHandling.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/DescriptorPool.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/ReadbackHeap.h"
|
||||
#include "core/providers/dml/DmlExecutionProvider/src/AllocationInfo.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -186,6 +197,11 @@ std::unique_ptr<IDataTransfer> GetGPUDataTransfer() {
|
|||
#endif
|
||||
|
||||
#ifdef USE_DML
|
||||
|
||||
constexpr GUID execution_context_guid = {0x50fd773b, 0x4462, 0x4b28, {0x98, 0x9e, 0x8c, 0xa0, 0x54, 0x05, 0xbd, 0x4a}};
|
||||
constexpr GUID upload_heap_guid = {0x125235f9, 0xef41, 0x4043, {0xa4, 0x9d, 0xdd, 0xc9, 0x61, 0xe7, 0xdb, 0xee}};
|
||||
constexpr GUID dml_readback_heap_guid = {0x00d32df8, 0xea2d, 0x40bf, {0xa4, 0x47, 0x9c, 0xb4, 0xbc, 0xf1, 0x1d, 0x5e}};
|
||||
|
||||
AllocatorPtr GetDmlAllocator(OrtDevice::DeviceId id) {
|
||||
// Current approach is not thread-safe, but there are some bigger infra pieces to put together in order to make
|
||||
// multi-threaded DML allocation work, including maintaining a per-thread DML allocator.
|
||||
|
|
@ -196,13 +212,100 @@ AllocatorPtr GetDmlAllocator(OrtDevice::DeviceId id) {
|
|||
|
||||
auto hit = id_to_allocator_map->find(id);
|
||||
if (hit == id_to_allocator_map->end()) {
|
||||
auto dml_allocator = std::make_shared<Dml::DmlExternalBufferAllocator>(id);
|
||||
constexpr uint32_t device_id = 0;
|
||||
auto d3d12_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false);
|
||||
auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device.Get());
|
||||
|
||||
D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {};
|
||||
cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
|
||||
cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;
|
||||
|
||||
ComPtr<ID3D12CommandQueue> cmd_queue;
|
||||
ORT_THROW_IF_FAILED(
|
||||
d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf())));
|
||||
|
||||
auto context = std::make_shared<Dml::ExecutionContext>(d3d12_device.Get(), dml_device.Get(), cmd_queue.Get());
|
||||
|
||||
// We leak the upload and readback heaps to keep them alive, just like the map
|
||||
auto upload_heap = std::make_unique<Dml::PooledUploadHeap>(d3d12_device.Get(), context).release();
|
||||
auto readback_heap = std::make_unique<Dml::ReadbackHeap>(d3d12_device.Get(), context).release();
|
||||
|
||||
auto dml_allocator = std::make_shared<Dml::BucketizedBufferAllocator>(
|
||||
d3d12_device.Get(),
|
||||
context,
|
||||
CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT),
|
||||
D3D12_HEAP_FLAG_NONE,
|
||||
D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS,
|
||||
D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
|
||||
std::make_unique<Dml::DmlCommittedResourceAllocator>(d3d12_device.Get()));
|
||||
dml_allocator->SetDefaultRoundingMode(AllocatorRoundingMode::Enabled);
|
||||
context->SetAllocator(dml_allocator);
|
||||
|
||||
auto context_ptr = context.get();
|
||||
|
||||
ORT_THROW_IF_FAILED(d3d12_device->SetPrivateData(execution_context_guid, sizeof(context_ptr), &context_ptr));
|
||||
ORT_THROW_IF_FAILED(d3d12_device->SetPrivateData(upload_heap_guid, sizeof(upload_heap), &upload_heap));
|
||||
ORT_THROW_IF_FAILED(d3d12_device->SetPrivateData(dml_readback_heap_guid, sizeof(readback_heap), &readback_heap));
|
||||
|
||||
hit = id_to_allocator_map->emplace(id, std::move(dml_allocator)).first;
|
||||
}
|
||||
|
||||
return hit->second;
|
||||
}
|
||||
|
||||
// Uploads num_bytes of CPU memory starting at src into the D3D12 resource behind dst, and blocks until
// the upload has completed.
// dst is not a raw data pointer: it is reinterpreted as a Dml::AllocationInfo*, and the destination
// resource is obtained from it. The execution context and the upload heap are read back as raw pointers
// from the private data of the resource's D3D12 device (keys execution_context_guid / upload_heap_guid),
// so this only works for devices on which those entries were registered beforehand (GetDmlAllocator does
// this via SetPrivateData). Any failing D3D12 call is reported through ORT_THROW_IF_FAILED.
void CpuToDmlMemCpy(void* dst, const void* src, size_t num_bytes) {
|
||||
const auto* allocInfo = static_cast<const Dml::AllocationInfo*>(dst);
|
||||
ID3D12Resource* dst_data = allocInfo->GetResource();
|
||||
|
||||
ComPtr<ID3D12Device> d3d12_device;
|
||||
ORT_THROW_IF_FAILED(dst_data->GetDevice(IID_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf())));
|
||||
|
||||
Dml::ExecutionContext* context = nullptr;
|
||||
uint32_t context_size = gsl::narrow_cast<uint32_t>(sizeof(context));
|
||||
ORT_THROW_IF_FAILED(d3d12_device->GetPrivateData(execution_context_guid, &context_size, &context));
|
||||
|
||||
Dml::PooledUploadHeap* upload_heap = nullptr;
|
||||
uint32_t upload_heap_size = gsl::narrow_cast<uint32_t>(sizeof(upload_heap));
|
||||
ORT_THROW_IF_FAILED(d3d12_device->GetPrivateData(upload_heap_guid, &upload_heap_size, &upload_heap));
|
||||
|
||||
upload_heap->BeginUploadToGpu(
|
||||
dst_data, 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, gsl::make_span(static_cast<const std::byte*>(src), num_bytes));
|
||||
context->Flush();
|
||||
|
||||
// We don't use the same command queue as the execution provider, so we need to sync to make sure that all data has
|
||||
// been uploaded to the resource. This function is usually called before inference just to upload initial data to the
|
||||
// GPU, so it shouldn't be a bottleneck.
|
||||
context->GetCurrentCompletionEvent().WaitForSignal();
|
||||
}
|
||||
|
||||
// Reads num_bytes back from the D3D12 resource behind src into CPU memory starting at dst.
// src is not a raw data pointer: it is reinterpreted as a Dml::AllocationInfo*, and the source resource
// is obtained from it. The execution context and the readback heap are read back as raw pointers from the
// private data of the resource's D3D12 device (keys execution_context_guid / dml_readback_heap_guid), so
// those entries must have been registered on the device beforehand (GetDmlAllocator does this via
// SetPrivateData). Any failing D3D12 call is reported through ORT_THROW_IF_FAILED.
// NOTE(review): `context` is fetched but not used afterwards in this function; confirm whether the
// lookup is only meant as a validity check.
void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
|
||||
const auto* allocInfo = static_cast<const Dml::AllocationInfo*>(src);
|
||||
ID3D12Resource* src_data = allocInfo->GetResource();
|
||||
|
||||
ComPtr<ID3D12Device> d3d12_device;
|
||||
ORT_THROW_IF_FAILED(src_data->GetDevice(IID_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf())));
|
||||
|
||||
Dml::ExecutionContext* context = nullptr;
|
||||
uint32_t context_size = gsl::narrow_cast<uint32_t>(sizeof(context));
|
||||
ORT_THROW_IF_FAILED(d3d12_device->GetPrivateData(execution_context_guid, &context_size, &context));
|
||||
|
||||
Dml::ReadbackHeap* readback_heap = nullptr;
|
||||
uint32_t readback_heap_size = gsl::narrow_cast<uint32_t>(sizeof(readback_heap));
|
||||
ORT_THROW_IF_FAILED(d3d12_device->GetPrivateData(dml_readback_heap_guid, &readback_heap_size, &readback_heap));
|
||||
|
||||
// ReadbackFromGpu already syncs with the CPU and waits for the copy to be completed, so we don't need to sync after
|
||||
// this call
|
||||
readback_heap->ReadbackFromGpu(
|
||||
gsl::make_span(static_cast<std::byte*>(dst), num_bytes), src_data, 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
|
||||
}
|
||||
|
||||
const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetDmlToHostMemCpyFunction() {
|
||||
static std::unordered_map<OrtDevice::DeviceType, MemCpyFunc> map{
|
||||
{OrtDevice::GPU, DmlToCpuMemCpy}};
|
||||
|
||||
return ↦
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef USE_CANN
|
||||
|
|
|
|||
|
|
@ -77,6 +77,12 @@ std::unique_ptr<IDataTransfer> GetGPUDataTransfer();
|
|||
|
||||
AllocatorPtr GetDmlAllocator(OrtDevice::DeviceId id);
|
||||
|
||||
void CpuToDmlMemCpy(void* dst, const void* src, size_t num_bytes);
|
||||
|
||||
void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes);
|
||||
|
||||
const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetDmlToHostMemCpyFunction();
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef USE_CANN
|
||||
|
|
|
|||
|
|
@ -63,7 +63,12 @@ void addOrtValueMethods(pybind11::module& m) {
|
|||
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
|
||||
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA
|
||||
CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy);
|
||||
|
||||
#elif USE_DML
|
||||
// InputDeflist is null because OrtValue creation is not tied to a specific model
|
||||
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
|
||||
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML
|
||||
CreateGenericMLValue(
|
||||
nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy);
|
||||
#else
|
||||
throw std::runtime_error(
|
||||
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
|
||||
|
|
@ -126,6 +131,12 @@ void addOrtValueMethods(pybind11::module& m) {
|
|||
values_type,
|
||||
*(ml_value->GetMutable<Tensor>()),
|
||||
CpuToRocmMemCpy);
|
||||
#elif USE_DML
|
||||
onnxruntime::python::CopyDataToTensor(
|
||||
py_values,
|
||||
values_type,
|
||||
*(ml_value->GetMutable<Tensor>()),
|
||||
CpuToDmlMemCpy);
|
||||
#else
|
||||
throw std::runtime_error(
|
||||
"Unsupported GPU device: Cannot find the supported GPU device.");
|
||||
|
|
@ -158,12 +169,18 @@ void addOrtValueMethods(pybind11::module& m) {
|
|||
throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
|
||||
}
|
||||
allocator = GetCudaAllocator(device.Id());
|
||||
#elif USE_DML
|
||||
allocator = GetDmlAllocator(device.Id());
|
||||
#else
|
||||
throw std::runtime_error(
|
||||
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
|
||||
"Please use the CUDA package of OnnxRuntime to use this feature.");
|
||||
#endif
|
||||
} else if (strcmp(GetDeviceName(device), DML) == 0) {
|
||||
#if USE_DML
|
||||
allocator = GetDmlAllocator(device.Id());
|
||||
#else
|
||||
throw std::runtime_error(
|
||||
"Can't allocate memory on the DirectML device using this package of OnnxRuntime. "
|
||||
"Please use the DirectML package of OnnxRuntime to use this feature.");
|
||||
#endif
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device");
|
||||
|
|
@ -290,11 +307,13 @@ void addOrtValueMethods(pybind11::module& m) {
|
|||
#ifdef USE_CUDA
|
||||
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetCudaToHostMemCpyFunction());
|
||||
#elif USE_ROCM
|
||||
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetRocmToHostMemCpyFunction());
|
||||
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetRocmToHostMemCpyFunction());
|
||||
#elif USE_CANN
|
||||
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetCannToHostMemCpyFunction());
|
||||
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetCannToHostMemCpyFunction());
|
||||
#elif USE_DML
|
||||
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetDmlToHostMemCpyFunction());
|
||||
#else
|
||||
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, nullptr);
|
||||
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, nullptr);
|
||||
#endif
|
||||
return obj;
|
||||
})
|
||||
|
|
|
|||
|
|
@ -237,7 +237,11 @@ const char* GetDeviceName(const OrtDevice& device) {
|
|||
case OrtDevice::CPU:
|
||||
return CPU;
|
||||
case OrtDevice::GPU:
|
||||
#ifdef USE_DML
|
||||
return DML;
|
||||
#else
|
||||
return CUDA;
|
||||
#endif
|
||||
case OrtDevice::FPGA:
|
||||
return "FPGA";
|
||||
case OrtDevice::NPU:
|
||||
|
|
|
|||
|
|
@ -16,40 +16,43 @@ from onnxruntime.capi._pybind_state import OrtDevice as C_OrtDevice # pylint: d
|
|||
from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue
|
||||
from onnxruntime.capi._pybind_state import OrtValueVector, SessionIOBinding
|
||||
|
||||
test_params = [
|
||||
("cuda", "CUDAExecutionProvider", C_OrtDevice.cuda),
|
||||
("dml", "DmlExecutionProvider", C_OrtDevice.dml),
|
||||
]
|
||||
|
||||
|
||||
class TestIOBinding(unittest.TestCase):
|
||||
def create_ortvalue_input_on_gpu(self):
|
||||
def _create_ortvalue_input_on_gpu(self, device):
|
||||
return onnxrt.OrtValue.ortvalue_from_numpy(
|
||||
np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32), "cuda", 0
|
||||
np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32), device, 0
|
||||
)
|
||||
|
||||
def create_ortvalue_alternate_input_on_gpu(self):
|
||||
def _create_ortvalue_alternate_input_on_gpu(self, device):
|
||||
return onnxrt.OrtValue.ortvalue_from_numpy(
|
||||
np.array([[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]], dtype=np.float32),
|
||||
"cuda",
|
||||
device,
|
||||
0,
|
||||
)
|
||||
|
||||
def create_uninitialized_ortvalue_input_on_gpu(self):
|
||||
return onnxrt.OrtValue.ortvalue_from_shape_and_type([3, 2], np.float32, "cuda", 0)
|
||||
def _create_uninitialized_ortvalue_input_on_gpu(self, device):
|
||||
return onnxrt.OrtValue.ortvalue_from_shape_and_type([3, 2], np.float32, device, 0)
|
||||
|
||||
def create_numpy_input(self):
|
||||
def _create_numpy_input(self):
|
||||
return np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
|
||||
|
||||
def create_expected_output(self):
|
||||
def _create_expected_output(self):
|
||||
return np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32)
|
||||
|
||||
def create_expected_output_alternate(self):
|
||||
def _create_expected_output_alternate(self):
|
||||
return np.array([[2.0, 8.0], [18.0, 32.0], [50.0, 72.0]], dtype=np.float32)
|
||||
|
||||
def test_bind_input_to_cpu_arr(self):
|
||||
self.create_numpy_input()
|
||||
|
||||
session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers())
|
||||
io_binding = session.io_binding()
|
||||
|
||||
# Bind Numpy object (input) that's on CPU to wherever the model needs it
|
||||
io_binding.bind_cpu_input("X", self.create_numpy_input())
|
||||
io_binding.bind_cpu_input("X", self._create_numpy_input())
|
||||
|
||||
# Bind output to CPU
|
||||
io_binding.bind_output("Y")
|
||||
|
|
@ -57,254 +60,280 @@ class TestIOBinding(unittest.TestCase):
|
|||
# Invoke Run
|
||||
session.run_with_iobinding(io_binding)
|
||||
|
||||
# Sync if different CUDA streams
|
||||
# Sync if different streams
|
||||
io_binding.synchronize_outputs()
|
||||
|
||||
# Get outputs over to CPU (the outputs which were bound to CUDA will get copied over to the host here)
|
||||
# Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host here)
|
||||
ort_output = io_binding.copy_outputs_to_cpu()[0]
|
||||
|
||||
# Validate results
|
||||
self.assertTrue(np.array_equal(self.create_expected_output(), ort_output))
|
||||
self.assertTrue(np.array_equal(self._create_expected_output(), ort_output))
|
||||
|
||||
@unittest.skip("Could not find an implementation for Identity(19) node with name ''")
|
||||
def test_bind_input_types(self):
|
||||
opset = onnx_opset_version()
|
||||
devices = [
|
||||
(
|
||||
C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0),
|
||||
["CPUExecutionProvider"],
|
||||
)
|
||||
]
|
||||
if "CUDAExecutionProvider" in onnxrt.get_all_providers():
|
||||
devices.append(
|
||||
(
|
||||
C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0),
|
||||
["CUDAExecutionProvider"],
|
||||
)
|
||||
)
|
||||
for device, execution_provider, generate_device in test_params:
|
||||
with self.subTest(execution_provider):
|
||||
if execution_provider not in onnxrt.get_available_providers():
|
||||
self.skipTest(f"Skipping on {device.upper()}.")
|
||||
|
||||
for device, provider in devices:
|
||||
for dtype in [
|
||||
np.float32,
|
||||
np.float64,
|
||||
np.int32,
|
||||
np.uint32,
|
||||
np.int64,
|
||||
np.uint64,
|
||||
np.int16,
|
||||
np.uint16,
|
||||
np.int8,
|
||||
np.uint8,
|
||||
np.float16,
|
||||
np.bool_,
|
||||
]:
|
||||
with self.subTest(dtype=dtype, device=str(device)):
|
||||
x = np.arange(8).reshape((-1, 2)).astype(dtype)
|
||||
proto_dtype = NP_TYPE_TO_TENSOR_TYPE[x.dtype]
|
||||
opset = onnx_opset_version()
|
||||
devices = [
|
||||
(
|
||||
C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0),
|
||||
["CPUExecutionProvider"],
|
||||
),
|
||||
(
|
||||
C_OrtDevice(generate_device(), C_OrtDevice.default_memory(), 0),
|
||||
[execution_provider],
|
||||
),
|
||||
]
|
||||
|
||||
X = helper.make_tensor_value_info("X", proto_dtype, [None, x.shape[1]]) # noqa: N806
|
||||
Y = helper.make_tensor_value_info("Y", proto_dtype, [None, x.shape[1]]) # noqa: N806
|
||||
for inner_device, provider in devices:
|
||||
for dtype in [
|
||||
np.float32,
|
||||
np.float64,
|
||||
np.int32,
|
||||
np.uint32,
|
||||
np.int64,
|
||||
np.uint64,
|
||||
np.int16,
|
||||
np.uint16,
|
||||
np.int8,
|
||||
np.uint8,
|
||||
np.float16,
|
||||
np.bool_,
|
||||
]:
|
||||
with self.subTest(dtype=dtype, inner_device=str(inner_device)):
|
||||
x = np.arange(8).reshape((-1, 2)).astype(dtype)
|
||||
proto_dtype = NP_TYPE_TO_TENSOR_TYPE[x.dtype]
|
||||
|
||||
# inference
|
||||
node_add = helper.make_node("Identity", ["X"], ["Y"])
|
||||
X = helper.make_tensor_value_info("X", proto_dtype, [None, x.shape[1]]) # noqa: N806
|
||||
Y = helper.make_tensor_value_info("Y", proto_dtype, [None, x.shape[1]]) # noqa: N806
|
||||
|
||||
# graph
|
||||
graph_def = helper.make_graph([node_add], "lr", [X], [Y], [])
|
||||
model_def = helper.make_model(
|
||||
graph_def,
|
||||
producer_name="dummy",
|
||||
ir_version=7,
|
||||
producer_version="0",
|
||||
opset_imports=[helper.make_operatorsetid("", opset)],
|
||||
)
|
||||
# inference
|
||||
node_add = helper.make_node("Identity", ["X"], ["Y"])
|
||||
|
||||
sess = onnxrt.InferenceSession(model_def.SerializeToString(), providers=provider)
|
||||
# graph
|
||||
graph_def = helper.make_graph([node_add], "lr", [X], [Y], [])
|
||||
model_def = helper.make_model(
|
||||
graph_def,
|
||||
producer_name="dummy",
|
||||
ir_version=7,
|
||||
producer_version="0",
|
||||
opset_imports=[helper.make_operatorsetid("", opset)],
|
||||
)
|
||||
|
||||
bind = SessionIOBinding(sess._sess)
|
||||
ort_value = C_OrtValue.ortvalue_from_numpy(x, device)
|
||||
bind.bind_ortvalue_input("X", ort_value)
|
||||
bind.bind_output("Y", device)
|
||||
sess._sess.run_with_iobinding(bind, None)
|
||||
ortvaluevector = bind.get_outputs()
|
||||
self.assertIsInstance(ortvaluevector, OrtValueVector)
|
||||
ortvalue = bind.get_outputs()[0]
|
||||
y = ortvalue.numpy()
|
||||
assert_almost_equal(x, y)
|
||||
sess = onnxrt.InferenceSession(model_def.SerializeToString(), providers=provider)
|
||||
|
||||
bind = SessionIOBinding(sess._sess)
|
||||
bind.bind_input("X", device, dtype, x.shape, ort_value.data_ptr())
|
||||
bind.bind_output("Y", device)
|
||||
sess._sess.run_with_iobinding(bind, None)
|
||||
ortvalue = bind.get_outputs()[0]
|
||||
y = ortvalue.numpy()
|
||||
assert_almost_equal(x, y)
|
||||
bind = SessionIOBinding(sess._sess)
|
||||
ort_value = C_OrtValue.ortvalue_from_numpy(x, inner_device)
|
||||
bind.bind_ortvalue_input("X", ort_value)
|
||||
bind.bind_output("Y", inner_device)
|
||||
sess._sess.run_with_iobinding(bind, None)
|
||||
ortvaluevector = bind.get_outputs()
|
||||
self.assertIsInstance(ortvaluevector, OrtValueVector)
|
||||
ortvalue = bind.get_outputs()[0]
|
||||
y = ortvalue.numpy()
|
||||
assert_almost_equal(x, y)
|
||||
|
||||
bind = SessionIOBinding(sess._sess)
|
||||
bind.bind_input("X", inner_device, dtype, x.shape, ort_value.data_ptr())
|
||||
bind.bind_output("Y", inner_device)
|
||||
sess._sess.run_with_iobinding(bind, None)
|
||||
ortvalue = bind.get_outputs()[0]
|
||||
y = ortvalue.numpy()
|
||||
assert_almost_equal(x, y)
|
||||
|
||||
def test_bind_input_only(self):
|
||||
input = self.create_ortvalue_input_on_gpu()
|
||||
for device, execution_provider, _ in test_params:
|
||||
with self.subTest(execution_provider):
|
||||
if execution_provider not in onnxrt.get_available_providers():
|
||||
self.skipTest(f"Skipping on {device.upper()}.")
|
||||
input = self._create_ortvalue_input_on_gpu(device)
|
||||
|
||||
session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers())
|
||||
io_binding = session.io_binding()
|
||||
session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers())
|
||||
io_binding = session.io_binding()
|
||||
|
||||
# Bind input to CUDA
|
||||
io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr())
|
||||
# Bind input to the GPU
|
||||
io_binding.bind_input("X", device, 0, np.float32, [3, 2], input.data_ptr())
|
||||
|
||||
# Sync if different CUDA streams
|
||||
io_binding.synchronize_inputs()
|
||||
# Sync if different streams
|
||||
io_binding.synchronize_inputs()
|
||||
|
||||
# Bind output to CPU
|
||||
io_binding.bind_output("Y")
|
||||
# Bind output to CPU
|
||||
io_binding.bind_output("Y")
|
||||
|
||||
# Invoke Run
|
||||
session.run_with_iobinding(io_binding)
|
||||
# Invoke Run
|
||||
session.run_with_iobinding(io_binding)
|
||||
|
||||
# Sync if different CUDA streams
|
||||
io_binding.synchronize_outputs()
|
||||
# Sync if different streams
|
||||
io_binding.synchronize_outputs()
|
||||
|
||||
# Get outputs over to CPU (the outputs which were bound to CUDA will get copied over to the host here)
|
||||
ort_output = io_binding.copy_outputs_to_cpu()[0]
|
||||
# Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host
|
||||
# here)
|
||||
ort_output = io_binding.copy_outputs_to_cpu()[0]
|
||||
|
||||
# Validate results
|
||||
self.assertTrue(np.array_equal(self.create_expected_output(), ort_output))
|
||||
# Validate results
|
||||
self.assertTrue(np.array_equal(self._create_expected_output(), ort_output))
|
||||
|
||||
def test_bind_input_and_preallocated_output(self):
|
||||
input = self.create_ortvalue_input_on_gpu()
|
||||
for device, execution_provider, _ in test_params:
|
||||
with self.subTest(execution_provider):
|
||||
if execution_provider not in onnxrt.get_available_providers():
|
||||
self.skipTest(f"Skipping on {device.upper()}.")
|
||||
|
||||
session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers())
|
||||
io_binding = session.io_binding()
|
||||
input = self._create_ortvalue_input_on_gpu(device)
|
||||
|
||||
# Bind input to CUDA
|
||||
io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr())
|
||||
session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers())
|
||||
io_binding = session.io_binding()
|
||||
|
||||
# Bind output to CUDA
|
||||
output = self.create_uninitialized_ortvalue_input_on_gpu()
|
||||
io_binding.bind_output("Y", "cuda", 0, np.float32, [3, 2], output.data_ptr())
|
||||
# Bind input to the GPU
|
||||
io_binding.bind_input("X", device, 0, np.float32, [3, 2], input.data_ptr())
|
||||
|
||||
# Sync if different CUDA streams
|
||||
io_binding.synchronize_inputs()
|
||||
# Bind output to the GPU
|
||||
output = self._create_uninitialized_ortvalue_input_on_gpu(device)
|
||||
io_binding.bind_output("Y", device, 0, np.float32, [3, 2], output.data_ptr())
|
||||
|
||||
# Invoke Run
|
||||
session.run_with_iobinding(io_binding)
|
||||
# Sync if different streams
|
||||
io_binding.synchronize_inputs()
|
||||
|
||||
# Sync if different CUDA streams
|
||||
io_binding.synchronize_outputs()
|
||||
# Invoke Run
|
||||
session.run_with_iobinding(io_binding)
|
||||
|
||||
# Get outputs over to CPU (the outputs which were bound to CUDA will get copied over to the host here)
|
||||
ort_output_vals = io_binding.copy_outputs_to_cpu()[0]
|
||||
# Validate results
|
||||
self.assertTrue(np.array_equal(self.create_expected_output(), ort_output_vals))
|
||||
# Sync if different streams
|
||||
io_binding.synchronize_outputs()
|
||||
|
||||
# Validate if ORT actually wrote to pre-allocated buffer by copying the Torch allocated buffer
|
||||
# to the host and validating its contents
|
||||
ort_output_vals_in_cpu = output.numpy()
|
||||
# Validate results
|
||||
self.assertTrue(np.array_equal(self.create_expected_output(), ort_output_vals_in_cpu))
|
||||
# Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host
|
||||
# here)
|
||||
ort_output_vals = io_binding.copy_outputs_to_cpu()[0]
|
||||
# Validate results
|
||||
self.assertTrue(np.array_equal(self._create_expected_output(), ort_output_vals))
|
||||
|
||||
# Validate if ORT actually wrote to pre-allocated buffer by copying the allocated buffer
|
||||
# to the host and validating its contents
|
||||
ort_output_vals_in_cpu = output.numpy()
|
||||
# Validate results
|
||||
self.assertTrue(np.array_equal(self._create_expected_output(), ort_output_vals_in_cpu))
|
||||
|
||||
def test_bind_input_and_non_preallocated_output(self):
|
||||
session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers())
|
||||
io_binding = session.io_binding()
|
||||
for device, execution_provider, _ in test_params:
|
||||
with self.subTest(execution_provider):
|
||||
if execution_provider not in onnxrt.get_available_providers():
|
||||
self.skipTest(f"Skipping on {device.upper()}.")
|
||||
|
||||
# Bind input to CUDA
|
||||
io_binding.bind_input(
|
||||
"X",
|
||||
"cuda",
|
||||
0,
|
||||
np.float32,
|
||||
[3, 2],
|
||||
self.create_ortvalue_input_on_gpu().data_ptr(),
|
||||
)
|
||||
session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers())
|
||||
io_binding = session.io_binding()
|
||||
|
||||
# Bind output to CUDA
|
||||
io_binding.bind_output("Y", "cuda")
|
||||
input = self._create_ortvalue_input_on_gpu(device)
|
||||
|
||||
# Sync if different CUDA streams
|
||||
io_binding.synchronize_inputs()
|
||||
# Bind input to the GPU
|
||||
io_binding.bind_input(
|
||||
"X",
|
||||
device,
|
||||
0,
|
||||
np.float32,
|
||||
[3, 2],
|
||||
input.data_ptr(),
|
||||
)
|
||||
|
||||
# Invoke Run
|
||||
session.run_with_iobinding(io_binding)
|
||||
# Bind output to the GPU
|
||||
io_binding.bind_output("Y", device)
|
||||
|
||||
# Sync if different CUDA streams
|
||||
io_binding.synchronize_outputs()
|
||||
# Sync if different streams
|
||||
io_binding.synchronize_inputs()
|
||||
|
||||
# This call returns an OrtValue which has data allocated by ORT on CUDA
|
||||
ort_outputs = io_binding.get_outputs()
|
||||
self.assertEqual(len(ort_outputs), 1)
|
||||
self.assertEqual(ort_outputs[0].device_name(), "cuda")
|
||||
# Validate results (by copying results to CPU by creating a Numpy object)
|
||||
self.assertTrue(np.array_equal(self.create_expected_output(), ort_outputs[0].numpy()))
|
||||
# Invoke Run
|
||||
session.run_with_iobinding(io_binding)
|
||||
|
||||
# We should be able to repeat the above process as many times as we want - try once more
|
||||
ort_outputs = io_binding.get_outputs()
|
||||
self.assertEqual(len(ort_outputs), 1)
|
||||
self.assertEqual(ort_outputs[0].device_name(), "cuda")
|
||||
# Validate results (by copying results to CPU by creating a Numpy object)
|
||||
self.assertTrue(np.array_equal(self.create_expected_output(), ort_outputs[0].numpy()))
|
||||
# Sync if different streams
|
||||
io_binding.synchronize_outputs()
|
||||
|
||||
# Change the bound input and validate the results in the same bound OrtValue
|
||||
# Bind alternate input to CUDA
|
||||
io_binding.bind_input(
|
||||
"X",
|
||||
"cuda",
|
||||
0,
|
||||
np.float32,
|
||||
[3, 2],
|
||||
self.create_ortvalue_alternate_input_on_gpu().data_ptr(),
|
||||
)
|
||||
# This call returns an OrtValue which has data allocated by ORT on the GPU
|
||||
ort_outputs = io_binding.get_outputs()
|
||||
self.assertEqual(len(ort_outputs), 1)
|
||||
self.assertEqual(ort_outputs[0].device_name(), device)
|
||||
# Validate results (by copying results to CPU by creating a Numpy object)
|
||||
self.assertTrue(np.array_equal(self._create_expected_output(), ort_outputs[0].numpy()))
|
||||
|
||||
# Sync if different CUDA streams
|
||||
io_binding.synchronize_inputs()
|
||||
# We should be able to repeat the above process as many times as we want - try once more
|
||||
ort_outputs = io_binding.get_outputs()
|
||||
self.assertEqual(len(ort_outputs), 1)
|
||||
self.assertEqual(ort_outputs[0].device_name(), device)
|
||||
# Validate results (by copying results to CPU by creating a Numpy object)
|
||||
self.assertTrue(np.array_equal(self._create_expected_output(), ort_outputs[0].numpy()))
|
||||
|
||||
# Invoke Run
|
||||
session.run_with_iobinding(io_binding)
|
||||
input = self._create_ortvalue_alternate_input_on_gpu(device)
|
||||
|
||||
# Sync if different CUDA streams
|
||||
io_binding.synchronize_outputs()
|
||||
# Change the bound input and validate the results in the same bound OrtValue
|
||||
# Bind alternate input to the GPU
|
||||
io_binding.bind_input(
|
||||
"X",
|
||||
device,
|
||||
0,
|
||||
np.float32,
|
||||
[3, 2],
|
||||
input.data_ptr(),
|
||||
)
|
||||
|
||||
# This call returns an OrtValue which has data allocated by ORT on CUDA
|
||||
ort_outputs = io_binding.get_outputs()
|
||||
self.assertEqual(len(ort_outputs), 1)
|
||||
self.assertEqual(ort_outputs[0].device_name(), "cuda")
|
||||
# Validate results (by copying results to CPU by creating a Numpy object)
|
||||
self.assertTrue(np.array_equal(self.create_expected_output_alternate(), ort_outputs[0].numpy()))
|
||||
# Sync if different streams
|
||||
io_binding.synchronize_inputs()
|
||||
|
||||
# Invoke Run
|
||||
session.run_with_iobinding(io_binding)
|
||||
|
||||
# Sync if different streams
|
||||
io_binding.synchronize_outputs()
|
||||
|
||||
# This call returns an OrtValue which has data allocated by ORT on the GPU
|
||||
ort_outputs = io_binding.get_outputs()
|
||||
self.assertEqual(len(ort_outputs), 1)
|
||||
self.assertEqual(ort_outputs[0].device_name(), device)
|
||||
# Validate results (by copying results to CPU by creating a Numpy object)
|
||||
self.assertTrue(np.array_equal(self._create_expected_output_alternate(), ort_outputs[0].numpy()))
|
||||
|
||||
def test_bind_input_and_bind_output_with_ortvalues(self):
|
||||
session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers())
|
||||
io_binding = session.io_binding()
|
||||
for device, execution_provider, _ in test_params:
|
||||
with self.subTest(execution_provider):
|
||||
if execution_provider not in onnxrt.get_available_providers():
|
||||
self.skipTest(f"Skipping on {device.upper()}.")
|
||||
|
||||
# Bind ortvalue as input
|
||||
input_ortvalue = self.create_ortvalue_input_on_gpu()
|
||||
io_binding.bind_ortvalue_input("X", input_ortvalue)
|
||||
session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers())
|
||||
io_binding = session.io_binding()
|
||||
|
||||
# Bind ortvalue as output
|
||||
output_ortvalue = self.create_uninitialized_ortvalue_input_on_gpu()
|
||||
io_binding.bind_ortvalue_output("Y", output_ortvalue)
|
||||
# Bind ortvalue as input
|
||||
input_ortvalue = self._create_ortvalue_input_on_gpu(device)
|
||||
io_binding.bind_ortvalue_input("X", input_ortvalue)
|
||||
|
||||
# Sync if different CUDA streams
|
||||
io_binding.synchronize_inputs()
|
||||
# Bind ortvalue as output
|
||||
output_ortvalue = self._create_uninitialized_ortvalue_input_on_gpu(device)
|
||||
io_binding.bind_ortvalue_output("Y", output_ortvalue)
|
||||
|
||||
# Invoke Run
|
||||
session.run_with_iobinding(io_binding)
|
||||
# Sync if different streams
|
||||
io_binding.synchronize_inputs()
|
||||
|
||||
# Sync if different CUDA streams
|
||||
io_binding.synchronize_outputs()
|
||||
# Invoke Run
|
||||
session.run_with_iobinding(io_binding)
|
||||
|
||||
# Inspect contents of output_ortvalue and make sure that it has the right contents
|
||||
self.assertTrue(np.array_equal(self.create_expected_output(), output_ortvalue.numpy()))
|
||||
# Sync if different streams
|
||||
io_binding.synchronize_outputs()
|
||||
|
||||
# Bind another ortvalue as input
|
||||
input_ortvalue_2 = self.create_ortvalue_alternate_input_on_gpu()
|
||||
io_binding.bind_ortvalue_input("X", input_ortvalue_2)
|
||||
# Inspect contents of output_ortvalue and make sure that it has the right contents
|
||||
self.assertTrue(np.array_equal(self._create_expected_output(), output_ortvalue.numpy()))
|
||||
|
||||
# Sync if different CUDA streams
|
||||
io_binding.synchronize_inputs()
|
||||
# Bind another ortvalue as input
|
||||
input_ortvalue_2 = self._create_ortvalue_alternate_input_on_gpu(device)
|
||||
io_binding.bind_ortvalue_input("X", input_ortvalue_2)
|
||||
|
||||
# Invoke Run
|
||||
session.run_with_iobinding(io_binding)
|
||||
# Sync if different streams
|
||||
io_binding.synchronize_inputs()
|
||||
|
||||
# Sync if different CUDA streams
|
||||
io_binding.synchronize_outputs()
|
||||
# Invoke Run
|
||||
session.run_with_iobinding(io_binding)
|
||||
|
||||
# Inspect contents of output_ortvalue and make sure that it has the right contents
|
||||
self.assertTrue(np.array_equal(self.create_expected_output_alternate(), output_ortvalue.numpy()))
|
||||
# Sync if different streams
|
||||
io_binding.synchronize_outputs()
|
||||
|
||||
# Inspect contents of output_ortvalue and make sure that it has the right contents
|
||||
self.assertTrue(np.array_equal(self._create_expected_output_alternate(), output_ortvalue.numpy()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1840,13 +1840,12 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
|
|||
[sys.executable, "onnxruntime_test_python_symbolic_shape_infer.py"], cwd=cwd, dll_path=dll_path
|
||||
)
|
||||
|
||||
# For CUDA enabled builds test IOBinding feature
|
||||
if args.use_cuda:
|
||||
# We need to have Torch installed to test the IOBinding feature
|
||||
# which currently uses Torch's allocator to allocate GPU memory for testing
|
||||
# For CUDA or DML enabled builds test IOBinding feature
|
||||
if args.use_cuda or args.use_dml:
|
||||
log.info("Testing IOBinding feature")
|
||||
run_subprocess([sys.executable, "onnxruntime_test_python_iobinding.py"], cwd=cwd, dll_path=dll_path)
|
||||
|
||||
if args.use_cuda:
|
||||
log.info("Testing CUDA Graph feature")
|
||||
run_subprocess([sys.executable, "onnxruntime_test_python_cudagraph.py"], cwd=cwd, dll_path=dll_path)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue