onnxruntime/include/onnxruntime/core/session/environment.h
cao lei dd72192cf4
ExecutionProvider API refactor - move allocator from EP level to SessionState level and indexed by OrtDevice (#15833)
### Description
This PR refactors the ExecutionProvider API for memory management:
allocators move from the EP level to the SessionState level and are
indexed by OrtDevice.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This PR refactors the ExecutionProvider API for memory management:
allocators move from the EP level to the SessionState level and are
indexed by OrtDevice. With this change, EPs are relieved of the burden of
maintaining allocators, which makes the API friendlier for EP developers.

---------

Co-authored-by: Lei Cao <leca@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
2023-06-19 17:44:45 -07:00

103 lines
4 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <atomic>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "core/common/common.h"
#include "core/common/status.h"
#include "core/platform/threadpool.h"
#include "core/common/logging/logging.h"
#include "core/framework/allocator.h"
// Forward declaration: this header only refers to OrtThreadingOptions through
// pointers, so the full definition is not needed here.
struct OrtThreadingOptions;
namespace onnxruntime {
/** TODO: remove this class
   Runtime environment for onnxruntime.
   Owns the logging manager, the intra-op and inter-op thread pools, and the
   allocators registered for sharing between sessions.
   Create one instance and keep it for the duration of execution.
*/
class Environment {
 public:
  Environment() = default;

  /**
     Create and initialize the runtime environment.
     @param logging_manager manager instance that will enable per session logger output using
     session_options.session_logid as the logger id in messages.
     If nullptr, the default LoggingManager MUST have been created previously as it will be used
     for logging. This will use the default logger id in messages.
     See core/common/logging/logging.h for details, and how LoggingManager::DefaultLogger works.
     @param environment output parameter; presumably receives the newly created Environment
     (the implementation is not visible in this header - confirm against the .cc file).
     @param tp_options optional set of parameters controlling the number of intra and inter op
     threads for the global threadpools.
     @param create_global_thread_pools determine if this function will create the global
     threadpools or not.
  */
  static Status Create(std::unique_ptr<logging::LoggingManager> logging_manager,
                       std::unique_ptr<Environment>& environment,
                       const OrtThreadingOptions* tp_options = nullptr,
                       bool create_global_thread_pools = false);

  /** Non-owning pointer to the logging manager currently held by this environment. */
  logging::LoggingManager* GetLoggingManager() const { return logging_manager_.get(); }

  /** Takes ownership of logging_manager, replacing the one currently held. */
  void SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager) {
    logging_manager_ = std::move(logging_manager);
  }

  /** Non-owning pointer to the intra-op thread pool held by this environment. */
  onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPool() const { return intra_op_thread_pool_.get(); }

  /** Non-owning pointer to the inter-op thread pool held by this environment. */
  onnxruntime::concurrency::ThreadPool* GetInterOpThreadPool() const { return inter_op_thread_pool_.get(); }

  /** Returns the stored create_global_thread_pools_ flag (false by default). */
  bool EnvCreatedWithGlobalThreadPools() const { return create_global_thread_pools_; }

  /**
   * Registers an allocator for sharing between multiple sessions.
   * Returns an error if an allocator with the same OrtMemoryInfo is already registered.
   */
  Status RegisterAllocator(AllocatorPtr allocator);

  /**
   * Creates and registers an allocator for sharing between multiple sessions.
   * Returns an error if an allocator with the same OrtMemoryInfo is already registered.
   */
  Status CreateAndRegisterAllocator(const OrtMemoryInfo& mem_info,
                                    const OrtArenaCfg* arena_cfg = nullptr);

  /**
   * Creates and registers an allocator, specified by provider_type, for sharing between
   * multiple sessions.
   * Returns an error if an allocator with the same OrtMemoryInfo is already registered.
   * For provider_type please refer to core/graph/constants.h
   */
  Status CreateAndRegisterAllocatorV2(const std::string& provider_type,
                                      const OrtMemoryInfo& mem_info,
                                      const std::unordered_map<std::string, std::string>& options,
                                      const OrtArenaCfg* arena_cfg = nullptr);

  /**
   * Removes an allocator that was previously registered for sharing between multiple sessions.
   */
  Status UnregisterAllocator(const OrtMemoryInfo& mem_info);

  /**
   * Returns the list of allocators registered in this env.
   * The reference refers to the environment's own container.
   */
  const std::vector<AllocatorPtr>& GetRegisteredSharedAllocators() const { return shared_allocators_; }

 private:
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment);

  Status Initialize(std::unique_ptr<logging::LoggingManager> logging_manager,
                    const OrtThreadingOptions* tp_options = nullptr,
                    bool create_global_thread_pools = false);

  // NOTE: member order determines destruction order (reverse of declaration); keep it stable.
  std::unique_ptr<logging::LoggingManager> logging_manager_;
  std::unique_ptr<onnxruntime::concurrency::ThreadPool> intra_op_thread_pool_;
  std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_;
  bool create_global_thread_pools_{false};
  std::vector<AllocatorPtr> shared_allocators_;
};
} // namespace onnxruntime