mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-22 02:30:26 +00:00
* static allocation.
* changes.
* contiguous dynamic allocation.
* contiguous dynamic allocation.
* fix bugs.
* fix bug.
* build errors.
* PR feedback.
* PR feedback.
* Update Graph builder for nccl_allreduce, mps.
* misc.
* fix windows build break.
* changes.
* fine-grained memory-time scheduling.
* merge.
* fix misc stuff.
* fix windows build.
* fix windows build.
* fix merge bug.
* merge conflicts.
* revert onnx-tensorrt submodule commit.
* fix submodule commit.
* misc.
* merge conflicts.
* Revert "merge conflicts."
This reverts commit 319a071a6e.
* merge conflict.
* merge conflict.
* merge conflicts.
* fixes.
* PR feedback.
* build break.
* build break.
* Add asserts.
* Add asserts.
* asserts.
* asserts.
* asserts.
* asserts.
* asserts.
* fixes.
* fixes.
Co-authored-by: Ubuntu <OrtTrainingDev3@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Co-authored-by: root <root@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
155 lines
5.7 KiB
C++
155 lines
5.7 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#pragma once
|
|
|
|
#include "core/graph/basic_types.h"
|
|
#include "core/framework/allocator.h"
|
|
#include "core/framework/data_types.h"
|
|
#include "core/framework/framework_common.h"
|
|
#include "core/framework/iexecutor.h"
|
|
#include "core/framework/session_state.h"
|
|
#include "core/framework/session_options.h"
|
|
|
|
// Forward declarations of two ONNX protobuf message types together with
// stream-insertion overloads for them. Only declarations appear here; the
// definitions of the overloads are not part of this header.
// NOTE(review): ONNX_NAMESPACE is expected to be a macro naming the ONNX
// namespace; it is not defined in this header - confirm where it comes from.
namespace ONNX_NAMESPACE {
class TensorShapeProto;
class TensorProto;

// Stream-insertion overload for a TensorShapeProto (output format is decided by the definition).
std::ostream& operator<<(std::ostream& out, const TensorShapeProto& shape_proto);

// Stream-insertion overload for a TensorProto (output format is decided by the definition).
std::ostream& operator<<(std::ostream& out, const TensorProto& tensor_proto);
}  // namespace ONNX_NAMESPACE
|
|
|
|
namespace onnxruntime {
|
|
// Forward declarations of onnxruntime types. These lines only introduce the
// names as incomplete types; no definitions are provided here.
class ExecutionProviders;
struct FeedsFetchesInfo;
class FeedsFetchesManager;
struct MLValueCopyInfo;
class Graph;
class KernelDef;
class KernelRegistryManager;
class IExecutionProvider;
class Node;
class Tensor;
|
|
|
|
// Forward declaration of logging::Logger, which ExecuteGraph and
// ExecuteSubgraph take by const reference.
namespace logging {
class Logger;
}  // namespace logging
|
|
|
|
namespace utils {
|
|
// Raw allocation helpers; only the declarations are visible in this header.
// NOTE(review): presumably DefaultAlloc returns a buffer of at least `size`
// bytes that must later be released with DefaultFree - confirm the alignment
// guarantee and the failure behavior against the definitions.
void* DefaultAlloc(size_t size);
void DefaultFree(void* p);
|
|
|
|
// Declaration only; the definition is not part of this header.
// NOTE(review): judging by the name, returns the execution provider type for
// the node input described by `info`. A reference is returned, so the string
// must outlive the call - presumably it is owned by `info` or the session
// state; confirm against the definition.
const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info);
|
|
|
|
// Declaration only; the definition is not part of this header.
// `orig_mlvalue` is read-only and `new_mlvalue` is the output parameter.
// NOTE(review): by its name this copies the input called `input_name` to the
// device it is needed on, writing the result to `new_mlvalue` - confirm what
// is stored in `new_mlvalue` when no cross-device copy is required.
common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
                                         const OrtValue& orig_mlvalue, OrtValue& new_mlvalue);
|
|
|
|
// Searches the allocation plan from the session_state to find the OrtMemoryInfo for the value 'name'.
|
|
const OrtMemoryInfo& FindMemoryInfoForValue(const SessionState& session_state,
|
|
const std::string& name);
|
|
|
|
// Initialize the feed and fetch copy info using session_state.
|
|
// Determines the device that each graph input that will be fed will be consumed on,
|
|
// and the device that each graph output that will be fetched will be created on.
|
|
common::Status InitializeFeedFetchCopyInfo(const SessionState& session_state,
|
|
FeedsFetchesManager& feeds_fetches_manager);
|
|
|
|
// Finalize the feed and fetch copy info using session_state and the device and location information from the feeds
|
|
// and fetches that will be used in graph execution.
|
|
void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
|
|
const std::vector<OrtDevice>& feed_locations,
|
|
const std::vector<const OrtMemoryInfo*>& fetch_alloc_info);
|
|
|
|
// Execute the main graph. The feed_fetches_manager will be finalized based on the provided feeds and fetches.
|
|
common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
|
|
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
|
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger,
|
|
bool only_execute_path_to_fetches = false);
|
|
|
|
// Execute a subgraph. The feeds_fetches_manager should have been finalized prior to calling this function.
|
|
// See IControlFlowNode::SetupSubgraphExecutionInfo usage in the control flow kernels.
|
|
common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFetchesManager& feeds_fetches_manager,
|
|
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
|
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
|
|
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger);
|
|
|
|
template <typename T>
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() {
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
|
|
}
|
|
|
|
template <>
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<bool>() {
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
|
}
|
|
|
|
template <>
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<std::string>() {
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
|
}
|
|
|
|
template <>
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<float>() {
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
|
}
|
|
|
|
template <>
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<double>() {
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
|
|
}
|
|
|
|
template <>
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<MLFloat16>() {
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
|
|
}
|
|
|
|
template <>
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<BFloat16>() {
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
|
|
}
|
|
|
|
template <>
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<int8_t>() {
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
|
|
}
|
|
|
|
template <>
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint8_t>() {
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
|
}
|
|
|
|
template <>
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<int16_t>() {
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
|
|
}
|
|
|
|
template <>
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint16_t>() {
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
|
|
}
|
|
|
|
template <>
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<int32_t>() {
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
|
|
}
|
|
|
|
template <>
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint32_t>() {
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
|
|
}
|
|
|
|
template <>
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<int64_t>() {
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
|
}
|
|
|
|
template <>
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint64_t>() {
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
|
|
}
|
|
|
|
// Declaration only; the definition is not part of this header.
// NOTE(review): judging by the name, converts an ONNXTensorElementDataType to the
// int32_t data-type value used by ONNX TensorProto messages - confirm the mapping
// and the result for unmapped enumerators against the definition.
int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType);
|
|
|
|
// Declaration only; the definition is not part of this header.
// NOTE(review): judging by the name, checks that the input tensors of `context` were
// allocated back-to-back in one contiguous memory range and reports the outcome via
// the returned Status - confirm against the definition.
// OpKernelContext is not forward-declared in the visible part of this header, so it
// must be provided by one of the included headers.
common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context);
|
|
|
|
} // namespace utils
|
|
} // namespace onnxruntime
|