mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
CP 7fd1ce95a4 (#18560)
CP 7fd1ce95a4 for onnxruntime_perf_test
changes.
Co-authored-by: Sheil Kumar <sheilk@microsoft.com>
This commit is contained in:
parent
7f9e6c42c2
commit
e8209ce2b0
10 changed files with 113 additions and 46 deletions
|
|
@ -46,12 +46,12 @@ namespace Dml
|
|||
return GpuEvent{ m_lastFenceValue + 1, m_fence };
|
||||
}
|
||||
|
||||
void CommandQueue::QueueReference(IUnknown* object, bool waitForUnsubmittedWork)
|
||||
void CommandQueue::QueueReference(IUnknown* object, bool waitForUnsubmittedWork)
|
||||
{
|
||||
// If the CommandQueue is closing, then m_queuedReferences is being cleared -- it is not OK
|
||||
// to queue additional references at this time, since those references would be leaked. This
|
||||
// affects any objects in m_queuedReferences whose destructors indirectly call QueueReference;
|
||||
// for example, an allocation from BucketizedBufferAllocator attempts to queue a reference
|
||||
// If the CommandQueue is closing, then m_queuedReferences is being cleared -- it is not OK
|
||||
// to queue additional references at this time, since those references would be leaked. This
|
||||
// affects any objects in m_queuedReferences whose destructors indirectly call QueueReference;
|
||||
// for example, an allocation from BucketizedBufferAllocator attempts to queue a reference
|
||||
// to its underlying D3D resource when freed. Furthermore, these references are unnecessary
|
||||
// since Close() already blocks for scheduled GPU work before clearing m_queuedReferences.
|
||||
if (!m_closing)
|
||||
|
|
@ -68,7 +68,7 @@ namespace Dml
|
|||
m_queuedReferences.push_back(queuedReference);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void CommandQueue::Close()
|
||||
{
|
||||
// Wait for flushed work:
|
||||
|
|
@ -79,7 +79,7 @@ namespace Dml
|
|||
m_queuedReferences.clear();
|
||||
m_closing = false;
|
||||
}
|
||||
|
||||
|
||||
void CommandQueue::ReleaseCompletedReferences()
|
||||
{
|
||||
uint64_t completedValue = GetFence()->GetCompletedValue();
|
||||
|
|
@ -89,5 +89,4 @@ namespace Dml
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ namespace Dml
|
|||
: m_queue(std::make_shared<CommandQueue>(queue))
|
||||
, m_dmlRecorder(d3d12Device, dmlDevice, m_queue)
|
||||
{
|
||||
ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf())));
|
||||
ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf())));
|
||||
}
|
||||
|
||||
void ExecutionContext::SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator)
|
||||
|
|
@ -78,14 +78,14 @@ namespace Dml
|
|||
ID3D12GraphicsCommandList* commandList,
|
||||
_Outptr_ ID3D12Fence** fence,
|
||||
_Out_ uint64_t* completionValue
|
||||
)
|
||||
)
|
||||
{
|
||||
assert(!m_closed);
|
||||
|
||||
SetCommandRecorder(&m_dmlRecorder);
|
||||
m_dmlRecorder.ExecuteCommandList(commandList, fence, completionValue);
|
||||
}
|
||||
|
||||
|
||||
void ExecutionContext::InitializeOperator(
|
||||
IDMLCompiledOperator* op,
|
||||
const DML_BINDING_DESC& persistentResourceBinding,
|
||||
|
|
@ -110,7 +110,7 @@ namespace Dml
|
|||
}
|
||||
|
||||
void ExecutionContext::AddUAVBarrier()
|
||||
{
|
||||
{
|
||||
assert(!m_closed);
|
||||
SetCommandRecorder(&m_dmlRecorder);
|
||||
|
||||
|
|
@ -173,9 +173,9 @@ namespace Dml
|
|||
m_currentRecorder = nullptr;
|
||||
SetCommandRecorder(&m_dmlRecorder);
|
||||
}
|
||||
|
||||
void ExecutionContext::QueueReference(IUnknown* object)
|
||||
{
|
||||
|
||||
void ExecutionContext::QueueReference(IUnknown* object)
|
||||
{
|
||||
assert(!m_closed);
|
||||
// If something has been recorded into a command list but not submitted yet, it means that the *next* fence
|
||||
// value is the one to signal completion.
|
||||
|
|
@ -186,14 +186,14 @@ namespace Dml
|
|||
void ExecutionContext::Close()
|
||||
{
|
||||
assert(!m_closed);
|
||||
|
||||
|
||||
// Discard unflushed work and clear queued references. This prevents the circular reference:
|
||||
// Kernel --> ProviderImpl --> Context --> QueuedRefs --> Kernel
|
||||
m_queue->Close();
|
||||
m_currentRecorder = nullptr;
|
||||
m_closed = true;
|
||||
}
|
||||
|
||||
|
||||
GpuEvent ExecutionContext::GetCurrentCompletionEvent()
|
||||
{
|
||||
assert(!m_closed);
|
||||
|
|
|
|||
|
|
@ -20,13 +20,13 @@ namespace Dml
|
|||
public:
|
||||
// Constructs an ExecutionContext that executes on the supplied queue.
|
||||
ExecutionContext(
|
||||
ID3D12Device* d3d12Device,
|
||||
IDMLDevice* dmlDevice,
|
||||
ID3D12Device* d3d12Device,
|
||||
IDMLDevice* dmlDevice,
|
||||
ID3D12CommandQueue* queue);
|
||||
|
||||
void SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator);
|
||||
|
||||
// Waits for flushed work, discards unflushed work, and discards associated references to
|
||||
// Waits for flushed work, discards unflushed work, and discards associated references to
|
||||
// prevent circular references. Must be the last call on the object before destruction.
|
||||
void Close();
|
||||
|
||||
|
|
@ -75,12 +75,12 @@ namespace Dml
|
|||
// Returns an event which will become signaled when everything submitted to the execution context thus far has
|
||||
// completed execution on the GPU, including work that has yet to be flushed to the queue.
|
||||
GpuEvent GetCurrentCompletionEvent();
|
||||
|
||||
|
||||
// Adds a reference which will be released when queued GPU work is completed
|
||||
void QueueReference(IUnknown* object);
|
||||
|
||||
// Release any accumulated references who corresponding GPU fence values have
|
||||
// been reached.
|
||||
// been reached.
|
||||
void ReleaseCompletedReferences();
|
||||
|
||||
D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const;
|
||||
|
|
|
|||
|
|
@ -205,7 +205,7 @@ namespace Dml
|
|||
D3D12_FEATURE_DATA_D3D12_OPTIONS19 options19 = {};
|
||||
|
||||
// The call may fail in which case the default value is false
|
||||
d3d12Device->CheckFeatureSupport(static_cast<D3D12_FEATURE>(48) /*D3D12_FEATURE_D3D12_OPTIONS19*/, &options19, sizeof(options19));
|
||||
d3d12Device->CheckFeatureSupport(static_cast<D3D12_FEATURE>(48) /*D3D12_FEATURE_D3D12_OPTIONS19*/, &options19, sizeof(options19));
|
||||
m_areCustomHeapsSupported = options19.ComputeOnlyCustomHeapSupported;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -150,7 +150,7 @@ namespace Dml
|
|||
}
|
||||
|
||||
STDMETHOD_(bool, IsMcdmDevice)() const noexcept final;
|
||||
STDMETHOD_(bool, CustomHeapsSupported)() const noexcept final;
|
||||
STDMETHOD_(bool, CustomHeapsSupported)() const noexcept final;
|
||||
|
||||
STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final;
|
||||
bool DynamicGraphFusionEnabled() const noexcept;
|
||||
|
|
|
|||
|
|
@ -118,7 +118,6 @@ static bool IsGPU(IDXCoreAdapter* compute_adapter) {
|
|||
return compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_NPU_ADAPTER_ENUMERATION
|
||||
static bool IsNPU(IDXCoreAdapter* compute_adapter) {
|
||||
// Only considering hardware adapters
|
||||
if (!IsHardwareAdapter(compute_adapter)) {
|
||||
|
|
@ -126,7 +125,6 @@ static bool IsNPU(IDXCoreAdapter* compute_adapter) {
|
|||
}
|
||||
return !(compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS));
|
||||
}
|
||||
#endif
|
||||
|
||||
enum class DeviceType { GPU, NPU, BadDevice };
|
||||
|
||||
|
|
@ -327,7 +325,8 @@ static std::optional<OrtDmlPerformancePreference> ParsePerformancePreference(con
|
|||
}
|
||||
|
||||
static std::optional<OrtDmlDeviceFilter> ParseFilter(const ProviderOptions& provider_options) {
|
||||
static const std::string Filter = "filter";
|
||||
static const std::string Filter = "device_filter";
|
||||
static const std::string Any = "any";
|
||||
static const std::string Gpu = "gpu";
|
||||
#ifdef ENABLE_NPU_ADAPTER_ENUMERATION
|
||||
static const std::string Any = "any";
|
||||
|
|
|
|||
|
|
@ -58,6 +58,10 @@ namespace perftest {
|
|||
"\t-q [CUDA only] use separate stream for copy. \n"
|
||||
"\t-z: Set denormal as zero. When turning on this option reduces latency dramatically, a model may have denormals.\n"
|
||||
"\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n"
|
||||
"\t [DML only] [performance_preference]: DML device performance preference, options: 'default', 'minimum_power', 'high_performance', \n"
|
||||
"\t [DML only] [device_filter]: DML device filter, options: 'any', 'gpu', 'npu', \n"
|
||||
"\t [DML only] [disable_metacommands]: Options: 'true', 'false', \n"
|
||||
"\t [DML only] [enable_dynamic_graph_fusion]: Options: 'true', 'false', \n"
|
||||
"\t [OpenVINO only] [device_type]: Overrides the accelerator hardware type and precision with these values at runtime.\n"
|
||||
"\t [OpenVINO only] [device_id]: Selects a particular hardware device for inference.\n"
|
||||
"\t [OpenVINO only] [enable_npu_fast_compile]: Optionally enabled to speeds up the model's compilation on NPU device targets.\n"
|
||||
|
|
|
|||
|
|
@ -16,6 +16,10 @@
|
|||
#include "providers.h"
|
||||
#include "TestCase.h"
|
||||
|
||||
#ifdef USE_DML
|
||||
#include "core/providers/dml/dml_provider_factory.h"
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
#define strdup _strdup
|
||||
#endif
|
||||
|
|
@ -42,8 +46,8 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
|
|||
const TestModelInfo& m)
|
||||
: rand_engine_(rd()), input_names_(m.GetInputCount()), input_names_str_(m.GetInputCount()), input_length_(m.GetInputCount()) {
|
||||
Ort::SessionOptions session_options;
|
||||
const std::string& provider_name = performance_test_config.machine_config.provider_type_name;
|
||||
if (provider_name == onnxruntime::kDnnlExecutionProvider) {
|
||||
provider_name_ = performance_test_config.machine_config.provider_type_name;
|
||||
if (provider_name_ == onnxruntime::kDnnlExecutionProvider) {
|
||||
#ifdef USE_DNNL
|
||||
// Generate provider options
|
||||
OrtDnnlProviderOptions dnnl_options;
|
||||
|
|
@ -96,7 +100,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
|
|||
#else
|
||||
ORT_THROW("DNNL is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kCudaExecutionProvider) {
|
||||
} else if (provider_name_ == onnxruntime::kCudaExecutionProvider) {
|
||||
#ifdef USE_CUDA
|
||||
const auto& api = Ort::GetApi();
|
||||
OrtCUDAProviderOptionsV2* cuda_options;
|
||||
|
|
@ -161,7 +165,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
|
|||
#else
|
||||
ORT_THROW("CUDA is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kTensorrtExecutionProvider) {
|
||||
} else if (provider_name_ == onnxruntime::kTensorrtExecutionProvider) {
|
||||
#ifdef USE_TENSORRT
|
||||
const auto& api = Ort::GetApi();
|
||||
OrtTensorRTProviderOptionsV2* tensorrt_options;
|
||||
|
|
@ -215,7 +219,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
|
|||
#else
|
||||
ORT_THROW("TensorRT is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kOpenVINOExecutionProvider) {
|
||||
} else if (provider_name_ == onnxruntime::kOpenVINOExecutionProvider) {
|
||||
#ifdef USE_OPENVINO
|
||||
#ifdef _MSC_VER
|
||||
std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string);
|
||||
|
|
@ -251,7 +255,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
|
|||
ov_options[key] = value;
|
||||
} else {
|
||||
ORT_THROW(
|
||||
"[ERROR] [OpenVINO] You have selcted wrong configuration value for the key 'device_type'. "
|
||||
"[ERROR] [OpenVINO] You have selected a wrong configuration value for the key 'device_type'. "
|
||||
"Select from 'CPU_FP32', 'CPU_FP16', 'GPU_FP32', 'GPU.0_FP32', 'GPU.1_FP32', 'GPU_FP16', "
|
||||
"'GPU.0_FP16', 'GPU.1_FP16' or from"
|
||||
" HETERO/MULTI/AUTO options available. \n");
|
||||
|
|
@ -305,7 +309,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
|
|||
#else
|
||||
ORT_THROW("OpenVINO is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kQnnExecutionProvider) {
|
||||
} else if (provider_name_ == onnxruntime::kQnnExecutionProvider) {
|
||||
#ifdef USE_QNN
|
||||
#ifdef _MSC_VER
|
||||
std::string option_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string);
|
||||
|
|
@ -378,7 +382,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
|
|||
#else
|
||||
ORT_THROW("QNN is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kSnpeExecutionProvider) {
|
||||
} else if (provider_name_ == onnxruntime::kSnpeExecutionProvider) {
|
||||
#ifdef USE_SNPE
|
||||
#ifdef _MSC_VER
|
||||
std::string option_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string);
|
||||
|
|
@ -430,7 +434,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
|
|||
#else
|
||||
ORT_THROW("SNPE is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kNnapiExecutionProvider) {
|
||||
} else if (provider_name_ == onnxruntime::kNnapiExecutionProvider) {
|
||||
#ifdef USE_NNAPI
|
||||
uint32_t nnapi_flags = 0;
|
||||
#ifdef _MSC_VER
|
||||
|
|
@ -458,22 +462,81 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
|
|||
#else
|
||||
ORT_THROW("NNAPI is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kCoreMLExecutionProvider) {
|
||||
} else if (provider_name_ == onnxruntime::kCoreMLExecutionProvider) {
|
||||
#ifdef USE_COREML
|
||||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, 0));
|
||||
#else
|
||||
ORT_THROW("COREML is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kDmlExecutionProvider) {
|
||||
} else if (provider_name_ == onnxruntime::kDmlExecutionProvider) {
|
||||
#ifdef USE_DML
|
||||
std::unordered_map<std::string, std::string> dml_options;
|
||||
dml_options["performance_preference"] = "high_performance";
|
||||
dml_options["device_filter"] = "gpu";
|
||||
dml_options["disable_metacommands"] = "false";
|
||||
dml_options["enable_dynamic_graph_fusion"] = "false";
|
||||
#ifdef _MSC_VER
|
||||
std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string);
|
||||
#else
|
||||
std::string ov_string = performance_test_config.run_config.ep_runtime_config_string;
|
||||
#endif
|
||||
std::istringstream ss(ov_string);
|
||||
std::string token;
|
||||
while (ss >> token) {
|
||||
if (token == "") {
|
||||
continue;
|
||||
}
|
||||
auto pos = token.find("|");
|
||||
if (pos == std::string::npos || pos == 0 || pos == token.length()) {
|
||||
ORT_THROW("[ERROR] [DML] Use a '|' to separate the key and value for the run-time option you are trying to use.\n");
|
||||
}
|
||||
|
||||
auto key = token.substr(0, pos);
|
||||
auto value = token.substr(pos + 1);
|
||||
|
||||
if (key == "device_filter") {
|
||||
std::set<std::string> ov_supported_device_types = {"gpu", "npu"};
|
||||
if (ov_supported_device_types.find(value) != ov_supported_device_types.end()) {
|
||||
dml_options[key] = value;
|
||||
} else {
|
||||
ORT_THROW(
|
||||
"[ERROR] [DML] You have selected a wrong configuration value for the key 'device_filter'. "
|
||||
"Select from 'gpu', or 'npu' \n");
|
||||
}
|
||||
} else if (key == "performance_preference") {
|
||||
std::set<std::string> ov_supported_values = {"default", "high_performance", "minimal_power"};
|
||||
if (ov_supported_values.find(value) != ov_supported_values.end()) {
|
||||
dml_options[key] = value;
|
||||
} else {
|
||||
ORT_THROW(
|
||||
"[ERROR] [DML] You have selected a wrong configuration value for the key 'performance_preference'. "
|
||||
"Select from 'default', 'high_performance' or 'minimal_power' \n");
|
||||
}
|
||||
} else if (key == "disable_metacommands") {
|
||||
std::set<std::string> ov_supported_values = {"true", "True", "false", "False"};
|
||||
if (ov_supported_values.find(value) != ov_supported_values.end()) {
|
||||
dml_options[key] = value;
|
||||
} else {
|
||||
ORT_THROW(
|
||||
"[ERROR] [DML] You have selcted wrong value for the key 'disable_metacommands'. "
|
||||
"Select from 'true' or 'false' \n");
|
||||
}
|
||||
} else if (key == "enable_dynamic_graph_fusion") {
|
||||
std::set<std::string> ov_supported_values = {"true", "True", "false", "False"};
|
||||
if (ov_supported_values.find(value) != ov_supported_values.end()) {
|
||||
dml_options[key] = value;
|
||||
} else {
|
||||
ORT_THROW(
|
||||
"[ERROR] [DML] You have selcted wrong value for the key 'enable_dynamic_graph_fusion'. "
|
||||
"Select from 'true' or 'false' \n");
|
||||
}
|
||||
}
|
||||
}
|
||||
session_options.AppendExecutionProvider("DML", dml_options);
|
||||
#else
|
||||
ORT_THROW("DML is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kAclExecutionProvider) {
|
||||
} else if (provider_name_ == onnxruntime::kAclExecutionProvider) {
|
||||
#ifdef USE_ACL
|
||||
Ort::ThrowOnError(
|
||||
OrtSessionOptionsAppendExecutionProvider_ACL(session_options,
|
||||
|
|
@ -481,14 +544,14 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
|
|||
#else
|
||||
ORT_THROW("Acl is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kArmNNExecutionProvider) {
|
||||
} else if (provider_name_ == onnxruntime::kArmNNExecutionProvider) {
|
||||
#ifdef USE_ARMNN
|
||||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ArmNN(session_options,
|
||||
performance_test_config.run_config.enable_cpu_mem_arena ? 1 : 0));
|
||||
#else
|
||||
ORT_THROW("ArmNN is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kRocmExecutionProvider) {
|
||||
} else if (provider_name_ == onnxruntime::kRocmExecutionProvider) {
|
||||
#ifdef USE_ROCM
|
||||
OrtROCMProviderOptions rocm_options;
|
||||
rocm_options.miopen_conv_exhaustive_search = performance_test_config.run_config.cudnn_conv_algo;
|
||||
|
|
@ -498,7 +561,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
|
|||
#else
|
||||
ORT_THROW("ROCM is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kMIGraphXExecutionProvider) {
|
||||
} else if (provider_name_ == onnxruntime::kMIGraphXExecutionProvider) {
|
||||
#ifdef USE_MIGRAPHX
|
||||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(session_options, 0));
|
||||
OrtROCMProviderOptions rocm_options;
|
||||
|
|
@ -508,7 +571,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
|
|||
#else
|
||||
ORT_THROW("MIGraphX is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kXnnpackExecutionProvider) {
|
||||
} else if (provider_name_ == onnxruntime::kXnnpackExecutionProvider) {
|
||||
#ifdef USE_XNNPACK
|
||||
session_options.AddConfigEntry(kOrtSessionOptionsConfigAllowIntraOpSpinning, "0");
|
||||
session_options.AppendExecutionProvider(
|
||||
|
|
@ -516,7 +579,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
|
|||
#else
|
||||
ORT_THROW("Xnnpack is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kVitisAIExecutionProvider) {
|
||||
} else if (provider_name_ == onnxruntime::kVitisAIExecutionProvider) {
|
||||
#ifdef USE_VITISAI
|
||||
#ifdef _MSC_VER
|
||||
std::string option_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string);
|
||||
|
|
@ -544,7 +607,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
|
|||
#else
|
||||
ORT_THROW("VitisAI is not supported in this build\n");
|
||||
#endif
|
||||
} else if (!provider_name.empty() && provider_name != onnxruntime::kCpuExecutionProvider) {
|
||||
} else if (!provider_name_.empty() && provider_name_ != onnxruntime::kCpuExecutionProvider) {
|
||||
ORT_THROW("This backend is not included in perf test runner.\n");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ class OnnxRuntimeTestSession : public TestSession {
|
|||
std::vector<const char*> input_names_;
|
||||
std::vector<std::string> input_names_str_;
|
||||
const int input_length_;
|
||||
std::string provider_name_;
|
||||
};
|
||||
|
||||
} // namespace perftest
|
||||
|
|
|
|||
|
|
@ -274,8 +274,9 @@ std::unique_ptr<IExecutionProvider> DefaultCannExecutionProvider() {
|
|||
|
||||
std::unique_ptr<IExecutionProvider> DefaultDmlExecutionProvider() {
|
||||
#ifdef USE_DML
|
||||
if (auto factory = DMLProviderFactoryCreator::Create(0, false, false, false))
|
||||
if (auto factory = DMLProviderFactoryCreator::CreateFromOptions(nullptr, false, false)) {
|
||||
return factory->CreateProvider();
|
||||
}
|
||||
#endif
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue