mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-23 02:38:28 +00:00
92 lines
3.2 KiB
C++
92 lines
3.2 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#pragma once
|
|
|
|
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "core/framework/execution_provider.h"
#include "core/graph/graph_viewer.h"
#include "core/common/logging/logging.h"
|
|
|
|
namespace onnxruntime {
|
|
|
|
/**
|
|
Class for managing lookup of the execution providers in a session.
|
|
*/
|
|
class ExecutionProviders {
|
|
public:
|
|
ExecutionProviders() = default;
|
|
|
|
common::Status Add(const std::string& provider_id, std::unique_ptr<IExecutionProvider> p_exec_provider) {
|
|
// make sure there are no issues before we change any internal data structures
|
|
if (provider_idx_map_.find(provider_id) != provider_idx_map_.end()) {
|
|
auto status = ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Provider ", provider_id, " has already been registered.");
|
|
LOGS_DEFAULT(ERROR) << status.ErrorMessage();
|
|
return status;
|
|
}
|
|
|
|
for (const auto& allocator : p_exec_provider->GetAllocatorMap()) {
|
|
if (allocator_idx_map_.find(allocator->Info()) != allocator_idx_map_.end()) {
|
|
auto status = ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, allocator->Info(), " allocator already registered.");
|
|
LOGS_DEFAULT(ERROR) << status.ErrorMessage();
|
|
return status;
|
|
}
|
|
}
|
|
|
|
// index that provider will have after insertion
|
|
auto new_provider_idx = exec_providers_.size();
|
|
|
|
ONNXRUNTIME_IGNORE_RETURN_VALUE(provider_idx_map_.insert({provider_id, new_provider_idx}));
|
|
|
|
for (const auto& allocator : p_exec_provider->GetAllocatorMap()) {
|
|
ONNXRUNTIME_IGNORE_RETURN_VALUE(allocator_idx_map_.insert({allocator->Info(), new_provider_idx}));
|
|
}
|
|
|
|
exec_providers_.push_back(std::move(p_exec_provider));
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
const IExecutionProvider* Get(const onnxruntime::Node& node) const {
|
|
return Get(node.GetExecutionProviderType());
|
|
}
|
|
|
|
const IExecutionProvider* Get(onnxruntime::ProviderType provider_id) const {
|
|
auto it = provider_idx_map_.find(provider_id);
|
|
if (it == provider_idx_map_.end()) {
|
|
return nullptr;
|
|
}
|
|
|
|
return exec_providers_[it->second].get();
|
|
}
|
|
|
|
const IExecutionProvider* Get(const ONNXRuntimeAllocatorInfo& allocator_info) const {
|
|
auto it = allocator_idx_map_.find(allocator_info);
|
|
if (it == allocator_idx_map_.end()) {
|
|
return nullptr;
|
|
}
|
|
|
|
return exec_providers_[it->second].get();
|
|
}
|
|
|
|
bool Empty() const { return exec_providers_.empty(); }
|
|
|
|
using const_iterator = typename std::vector<std::unique_ptr<IExecutionProvider>>::const_iterator;
|
|
const_iterator begin() const noexcept { return exec_providers_.cbegin(); }
|
|
const_iterator end() const noexcept { return exec_providers_.cend(); }
|
|
|
|
private:
|
|
std::vector<std::unique_ptr<IExecutionProvider>> exec_providers_;
|
|
|
|
// maps for fast lookup of an index into exec_providers_
|
|
std::unordered_map<std::string, size_t> provider_idx_map_;
|
|
// using std::map as ONNXRuntimeAllocatorInfo would need a custom hash function to be used with unordered_map,
|
|
// and as this isn't performance critical it's not worth the maintenance overhead of adding one.
|
|
std::map<ONNXRuntimeAllocatorInfo, size_t> allocator_idx_map_;
|
|
};
|
|
} // namespace onnxruntime
|