mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
96 lines
3.5 KiB
C++
96 lines
3.5 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
|
|
|
|
#include <atomic>
|
|
#include <memory>
|
|
#include "core/common/common.h"
|
|
#include "core/common/status.h"
|
|
#include "core/platform/threadpool.h"
|
|
#include "core/common/logging/logging.h"
|
|
#include "core/framework/allocator.h"
|
|
|
|
// Forward declaration only: this header refers to OrtThreadingOptions solely through
// `const OrtThreadingOptions*` parameters, so the full definition is not required here.
struct OrtThreadingOptions;
|
|
namespace onnxruntime {
|
|
/** TODO: remove this class
|
|
Provides the runtime environment for onnxruntime.
|
|
Create one instance for the duration of execution.
|
|
*/
|
|
class Environment {
|
|
public:
|
|
/**
|
|
Create and initialize the runtime environment.
|
|
@param logging manager instance that will enable per session logger output using
|
|
session_options.session_logid as the logger id in messages.
|
|
If nullptr, the default LoggingManager MUST have been created previously as it will be used
|
|
for logging. This will use the default logger id in messages.
|
|
See core/common/logging/logging.h for details, and how LoggingManager::DefaultLogger works.
|
|
@param tp_options optional set of parameters controlling the number of intra and inter op threads for the global
|
|
threadpools.
|
|
@param create_global_thread_pools determine if this function will create the global threadpools or not.
|
|
*/
|
|
static Status Create(std::unique_ptr<logging::LoggingManager> logging_manager,
|
|
std::unique_ptr<Environment>& environment,
|
|
const OrtThreadingOptions* tp_options = nullptr,
|
|
bool create_global_thread_pools = false);
|
|
|
|
logging::LoggingManager* GetLoggingManager() const {
|
|
return logging_manager_.get();
|
|
}
|
|
|
|
void SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager) {
|
|
logging_manager_ = std::move(logging_manager);
|
|
}
|
|
|
|
onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPool() const {
|
|
return intra_op_thread_pool_.get();
|
|
}
|
|
|
|
onnxruntime::concurrency::ThreadPool* GetInterOpThreadPool() const {
|
|
return inter_op_thread_pool_.get();
|
|
}
|
|
|
|
bool EnvCreatedWithGlobalThreadPools() const {
|
|
return create_global_thread_pools_;
|
|
}
|
|
|
|
/**
|
|
* Registers an allocator for sharing between multiple sessions.
|
|
* Return an error if an allocator with the same OrtMemoryInfo is already registered.
|
|
*/
|
|
Status RegisterAllocator(AllocatorPtr allocator);
|
|
|
|
/**
|
|
* Creates and registers an allocator for sharing between multiple sessions.
|
|
* Return an error if an allocator with the same OrtMemoryInfo is already registered.
|
|
*/
|
|
Status CreateAndRegisterAllocator(const OrtMemoryInfo& mem_info, const OrtArenaCfg* arena_cfg = nullptr);
|
|
|
|
/**
|
|
* Returns the list of registered allocators in this env.
|
|
*/
|
|
const std::vector<AllocatorPtr>& GetRegisteredSharedAllocators() const {
|
|
return shared_allocators_;
|
|
}
|
|
|
|
/**
|
|
* Removes registered allocator that was previously registered for sharing between multiple sessions.
|
|
*/
|
|
Status UnregisterAllocator(const OrtMemoryInfo& mem_info);
|
|
|
|
private:
|
|
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment);
|
|
|
|
Environment() = default;
|
|
Status Initialize(std::unique_ptr<logging::LoggingManager> logging_manager,
|
|
const OrtThreadingOptions* tp_options = nullptr,
|
|
bool create_global_thread_pools = false);
|
|
|
|
std::unique_ptr<logging::LoggingManager> logging_manager_;
|
|
std::unique_ptr<onnxruntime::concurrency::ThreadPool> intra_op_thread_pool_;
|
|
std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_;
|
|
bool create_global_thread_pools_{false};
|
|
std::vector<AllocatorPtr> shared_allocators_;
|
|
};
|
|
} // namespace onnxruntime
|