2018-11-20 00:48:22 +00:00
|
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
|
// Licensed under the MIT License.
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include "core/graph/basic_types.h"
|
|
|
|
|
#include "core/framework/allocator.h"
|
2019-01-07 22:11:46 +00:00
|
|
|
#include "core/framework/data_types.h"
|
2019-01-17 18:51:23 +00:00
|
|
|
#include "core/framework/framework_common.h"
|
2019-01-29 09:48:10 +00:00
|
|
|
#include "core/framework/iexecutor.h"
|
2019-01-17 18:51:23 +00:00
|
|
|
#include "core/framework/session_state.h"
|
2019-10-14 16:48:19 +00:00
|
|
|
#include "core/framework/session_options.h"
|
2018-11-20 00:48:22 +00:00
|
|
|
|
2019-09-26 03:44:33 +00:00
|
|
|
// ONNX protobuf message types used by this header, forward-declared so the protobuf headers are
// not needed here, plus stream-insertion operators for them. The operators are only declared in
// this header; their definitions live elsewhere.
namespace ONNX_NAMESPACE {

class TensorProto;
class TensorShapeProto;

// Stream-insertion operator for a TensorShapeProto; returns a std::ostream reference
// (conventionally `out`, to allow chaining).
std::ostream& operator<<(std::ostream& out, const TensorShapeProto& shape_proto);

// Stream-insertion operator for a TensorProto; returns a std::ostream reference
// (conventionally `out`, to allow chaining).
std::ostream& operator<<(std::ostream& out, const TensorProto& tensor_proto);

}  // namespace ONNX_NAMESPACE
|
|
|
|
|
|
2018-11-20 00:48:22 +00:00
|
|
|
namespace onnxruntime {
|
|
|
|
|
// Forward declarations of onnxruntime types, so that their full definitions do not have to be
// included by this header. The declarations visible in this header use these types only by
// pointer or reference, and several are not referenced here at all.
// NOTE(review): confirm whether includers rely on the unused ones before removing any.
class ExecutionProviders;
struct FeedsFetchesInfo;
class FeedsFetchesManager;
struct MLValueCopyInfo;
class Graph;
class KernelDef;
class KernelRegistryManager;
class IExecutionProvider;
class Node;
// Used by utils::VerifyInputTensorsAllocatedContiguously later in this header. It is declared
// explicitly so the header does not rely on a transitive include to make the name visible.
class OpKernelContext;
class Tensor;
|
2018-11-20 00:48:22 +00:00
|
|
|
|
|
|
|
|
// The logger type is only needed by reference (see ExecuteGraph / ExecuteSubgraph), so a forward
// declaration is enough here.
namespace logging { class Logger; }  // namespace logging
|
|
|
|
|
|
|
|
|
|
namespace utils {
|
2019-08-23 19:06:35 +00:00
|
|
|
// Default allocation routine; only the declaration is in this header.
// NOTE(review): presumably allocates `size` bytes and returns a pointer that is later released
// with DefaultFree. Confirm in the implementation, including the failure behaviour.
void* DefaultAlloc(size_t size);
|
|
|
|
|
// Counterpart of DefaultAlloc; only the declaration is in this header.
// NOTE(review): presumably releases memory `p` previously obtained from DefaultAlloc. Confirm in
// the implementation whether nullptr is accepted.
void DefaultFree(void* p);
|
2018-11-20 00:48:22 +00:00
|
|
|
|
2019-01-17 18:51:23 +00:00
|
|
|
// Returns, by const reference, a provider-type string for the node described by `info`.
// NOTE(review): the name suggests this is the execution provider type of the node that consumes
// the input described by `info`. Confirm against the definition before relying on it.
const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info);
|
|
|
|
|
|
2019-05-17 14:52:59 +00:00
|
|
|
// Produces `new_mlvalue` from the graph input `orig_mlvalue` (named `input_name`), using
// `session_state` to decide where the input is needed. Reports the outcome as a common::Status.
// NOTE(review): going by the name, this copies the value to the device the input is consumed on.
// The definition is not in this header, so confirm the exact semantics (e.g. whether a no-op copy
// is possible) in the implementation.
common::Status CopyOneInputAcrossDevices(
    const SessionState& session_state,
    const std::string& input_name,
    const OrtValue& orig_mlvalue,
    OrtValue& new_mlvalue);
|
2019-01-17 18:51:23 +00:00
|
|
|
|
2019-09-10 05:46:00 +00:00
|
|
|
// Searches the allocation plan from the session_state to find the OrtMemoryInfo for the value 'name'.
|
|
|
|
|
const OrtMemoryInfo& FindMemoryInfoForValue(const SessionState& session_state,
|
|
|
|
|
const std::string& name);
|
|
|
|
|
|
|
|
|
|
// Initialize the feed and fetch copy info using session_state.
|
|
|
|
|
// Determines the device that each graph input that will be fed will be consumed on,
|
|
|
|
|
// and the device that each graph output that will be fetched will be created on.
|
|
|
|
|
common::Status InitializeFeedFetchCopyInfo(const SessionState& session_state,
|
|
|
|
|
FeedsFetchesManager& feeds_fetches_manager);
|
|
|
|
|
|
|
|
|
|
// Finalize the feed and fetch copy info using session_state and the device and location information from the feeds
|
|
|
|
|
// and fetches that will be used in graph execution.
|
2020-06-28 04:55:42 +00:00
|
|
|
void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
|
2019-09-10 05:46:00 +00:00
|
|
|
const std::vector<OrtDevice>& feed_locations,
|
|
|
|
|
const std::vector<const OrtMemoryInfo*>& fetch_alloc_info);
|
|
|
|
|
|
|
|
|
|
// Execute the main graph. The feed_fetches_manager will be finalized based on the provided feeds and fetches.
|
2019-05-17 14:52:59 +00:00
|
|
|
common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
|
|
|
|
|
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
2020-03-11 21:25:37 +00:00
|
|
|
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger,
|
|
|
|
|
bool only_execute_path_to_fetches = false);
|
2019-09-10 05:46:00 +00:00
|
|
|
|
|
|
|
|
// Execute a subgraph. The feeds_fetches_manager should have been finalized prior to calling this function.
|
|
|
|
|
// See IControlFlowNode::SetupSubgraphExecutionInfo usage in the control flow kernels.
|
|
|
|
|
common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFetchesManager& feeds_fetches_manager,
|
|
|
|
|
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
|
|
|
|
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
|
2019-10-14 16:48:19 +00:00
|
|
|
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger);
|
2019-01-07 22:11:46 +00:00
|
|
|
|
2019-11-09 01:47:06 +00:00
|
|
|
template <typename T>
|
|
|
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() {
|
|
|
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<bool>() {
|
|
|
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<std::string>() {
|
|
|
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<float>() {
|
|
|
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<double>() {
|
|
|
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<MLFloat16>() {
|
|
|
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<BFloat16>() {
|
|
|
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<int8_t>() {
|
|
|
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint8_t>() {
|
|
|
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<int16_t>() {
|
|
|
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint16_t>() {
|
|
|
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<int32_t>() {
|
|
|
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint32_t>() {
|
|
|
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<int64_t>() {
|
|
|
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint64_t>() {
|
|
|
|
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
|
|
|
|
|
}
|
2019-01-07 22:11:46 +00:00
|
|
|
|
2019-12-05 00:04:17 +00:00
|
|
|
// Converts an ONNXTensorElementDataType value to an int32_t.
// NOTE(review): presumably the result is the corresponding ONNX TensorProto data-type value. The
// behaviour for element types with no TensorProto counterpart is not visible here; confirm in the
// implementation.
int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType);
|
|
|
|
|
|
2020-11-02 07:05:46 +00:00
|
|
|
// Checks the input tensors available through `context` and reports the result as a
// common::Status.
// NOTE(review): judging by the name, a non-OK status means the inputs were not allocated
// contiguously. The exact criterion (ordering, alignment, gaps) is defined in the implementation;
// confirm there.
common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context);
|
|
|
|
|
|
2018-11-20 00:48:22 +00:00
|
|
|
} // namespace utils
|
|
|
|
|
} // namespace onnxruntime
|