support user_compute_stream for rocm ep (#19619)

### Description
<!-- Describe your changes. -->
According to the pr #19229 supporting cuda EP use external compute
stream, we add support for rocm EP.

And when we testing this feature with torch, we found torch use stream 0
for the default stream, and `torch.cuda.current_stream()` returns `0`
for current stream, but ort treat `0` or `nullptr` as invalid, and reset
has_user_compute_stream to false. 

Will remove has_user_compute_stream option in the future.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
The motivation for this pr is that we want to use torch.cuda.graph to
capture ort running kernel, which requires torch and ort are running in
the same stream, so we use this API to set ort's working stream.
This commit is contained in:
kailums 2024-02-27 11:31:03 +08:00 committed by GitHub
parent 8a71b65765
commit 6f566562ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 30 additions and 0 deletions

View file

@ -13,6 +13,8 @@ namespace onnxruntime {
namespace rocm {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kHasUserComputeStream = "has_user_compute_stream";
constexpr const char* kUserComputeStream = "user_compute_stream";
constexpr const char* kMemLimit = "gpu_mem_limit";
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
constexpr const char* kMiopenConvExhaustiveSearch = "miopen_conv_exhaustive_search";
@ -38,6 +40,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
void* alloc = nullptr;
void* free = nullptr;
void* empty_cache = nullptr;
void* user_compute_stream = nullptr;
ORT_THROW_IF_ERROR(
ProviderOptionsParser{}
.AddValueParser(
@ -52,6 +55,15 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
return Status::OK();
})
.AddAssignmentToReference(rocm::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream)
.AddValueParser(
rocm::provider_option_names::kUserComputeStream,
[&user_compute_stream](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
user_compute_stream = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddValueParser(
rocm::provider_option_names::kGpuExternalAlloc,
[&alloc](const std::string& value_str) -> Status {
@ -108,12 +120,18 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
ROCMExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache};
info.external_allocator_info = alloc_info;
info.user_compute_stream = user_compute_stream;
info.has_user_compute_stream = (user_compute_stream != nullptr);
return info;
}
ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecutionProviderInfo& info) {
const ProviderOptions options{
{rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
{rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
{rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
{rocm::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
{rocm::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
@ -135,6 +153,8 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution
ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const OrtROCMProviderOptions& info) {
const ProviderOptions options{
{rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
{rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
{rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
{rocm::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast<onnxruntime::ArenaExtendStrategy>(info.arena_extend_strategy))},
{rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)},

View file

@ -559,6 +559,16 @@ class TestInferenceSession(unittest.TestCase):
test_get_and_set_option_with_values("enable_hip_graph", ["1", "0"])
# test for user_compute_stream
option = options["ROCMExecutionProvider"]
option["user_compute_stream"] = "1"
sess.set_providers(["ROCMExecutionProvider"], [option])
new_options = sess.get_provider_options()
new_option = new_options["ROCMExecutionProvider"]
self.assertEqual(new_option["user_compute_stream"], "1")
# set user_compute_stream will set has_user_compute_stream to 1 too
self.assertEqual(new_option["has_user_compute_stream"], "1")
run_rocm_options_test()
def test_invalid_set_providers(self):