mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
support user_compute_stream for rocm ep (#19619)
### Description <!-- Describe your changes. --> According to the pr #19229 supporting cuda EP use external compute stream, we add support for rocm EP. And when we testing this feature with torch, we found torch use stream 0 for the default stream, and `torch.cuda.current_stream()` returns `0` for current stream, but ort treat `0` or `nullptr` as invalid, and reset has_user_compute_stream to false. Will remove has_user_compute_stream option in the future. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> The motivation for this pr is that we want to use torch.cuda.graph to capture ort running kernel, which requires torch and ort are running in the same stream, so we use this API to set ort's working stream.
This commit is contained in:
parent
8a71b65765
commit
6f566562ce
2 changed files with 30 additions and 0 deletions
|
|
@ -13,6 +13,8 @@ namespace onnxruntime {
|
|||
namespace rocm {
|
||||
namespace provider_option_names {
|
||||
constexpr const char* kDeviceId = "device_id";
|
||||
constexpr const char* kHasUserComputeStream = "has_user_compute_stream";
|
||||
constexpr const char* kUserComputeStream = "user_compute_stream";
|
||||
constexpr const char* kMemLimit = "gpu_mem_limit";
|
||||
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
|
||||
constexpr const char* kMiopenConvExhaustiveSearch = "miopen_conv_exhaustive_search";
|
||||
|
|
@ -38,6 +40,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
|
|||
void* alloc = nullptr;
|
||||
void* free = nullptr;
|
||||
void* empty_cache = nullptr;
|
||||
void* user_compute_stream = nullptr;
|
||||
ORT_THROW_IF_ERROR(
|
||||
ProviderOptionsParser{}
|
||||
.AddValueParser(
|
||||
|
|
@ -52,6 +55,15 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
|
|||
", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
|
||||
return Status::OK();
|
||||
})
|
||||
.AddAssignmentToReference(rocm::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream)
|
||||
.AddValueParser(
|
||||
rocm::provider_option_names::kUserComputeStream,
|
||||
[&user_compute_stream](const std::string& value_str) -> Status {
|
||||
size_t address;
|
||||
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
|
||||
user_compute_stream = reinterpret_cast<void*>(address);
|
||||
return Status::OK();
|
||||
})
|
||||
.AddValueParser(
|
||||
rocm::provider_option_names::kGpuExternalAlloc,
|
||||
[&alloc](const std::string& value_str) -> Status {
|
||||
|
|
@ -108,12 +120,18 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
|
|||
|
||||
ROCMExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache};
|
||||
info.external_allocator_info = alloc_info;
|
||||
|
||||
info.user_compute_stream = user_compute_stream;
|
||||
info.has_user_compute_stream = (user_compute_stream != nullptr);
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecutionProviderInfo& info) {
|
||||
const ProviderOptions options{
|
||||
{rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
|
||||
{rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
|
||||
{rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
|
||||
{rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
|
||||
{rocm::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
|
||||
{rocm::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
|
||||
|
|
@ -135,6 +153,8 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution
|
|||
ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const OrtROCMProviderOptions& info) {
|
||||
const ProviderOptions options{
|
||||
{rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
|
||||
{rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
|
||||
{rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
|
||||
{rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
|
||||
{rocm::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast<onnxruntime::ArenaExtendStrategy>(info.arena_extend_strategy))},
|
||||
{rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)},
|
||||
|
|
|
|||
|
|
@ -559,6 +559,16 @@ class TestInferenceSession(unittest.TestCase):
|
|||
|
||||
test_get_and_set_option_with_values("enable_hip_graph", ["1", "0"])
|
||||
|
||||
# test for user_compute_stream
|
||||
option = options["ROCMExecutionProvider"]
|
||||
option["user_compute_stream"] = "1"
|
||||
sess.set_providers(["ROCMExecutionProvider"], [option])
|
||||
new_options = sess.get_provider_options()
|
||||
new_option = new_options["ROCMExecutionProvider"]
|
||||
self.assertEqual(new_option["user_compute_stream"], "1")
|
||||
# set user_compute_stream will set has_user_compute_stream to 1 too
|
||||
self.assertEqual(new_option["has_user_compute_stream"], "1")
|
||||
|
||||
run_rocm_options_test()
|
||||
|
||||
def test_invalid_set_providers(self):
|
||||
|
|
|
|||
Loading…
Reference in a new issue