onnxruntime/onnxruntime/core/framework/utils.h

156 lines
5.7 KiB
C
Raw Normal View History

2018-11-20 00:48:22 +00:00
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/graph/basic_types.h"
#include "core/framework/allocator.h"
#include "core/framework/data_types.h"
#include "core/framework/framework_common.h"
#include "core/framework/iexecutor.h"
#include "core/framework/session_state.h"
#include "core/framework/session_options.h"
2018-11-20 00:48:22 +00:00
// Names two ONNX types without pulling in their definitions and declares stream-insertion
// operators for them. Only the declarations live here; the operator definitions are elsewhere.
// NOTE(review): ONNX_NAMESPACE is presumably a macro supplied by the ONNX headers - confirm.
namespace ONNX_NAMESPACE {
class TensorProto;
class TensorShapeProto;

// Writes 'tensor_proto' to 'out' and returns a stream reference so calls can be chained.
std::ostream& operator<<(std::ostream& out, const TensorProto& tensor_proto);

// Writes 'shape_proto' to 'out' and returns a stream reference so calls can be chained.
std::ostream& operator<<(std::ostream& out, const TensorShapeProto& shape_proto);
}  // namespace ONNX_NAMESPACE
2018-11-20 00:48:22 +00:00
namespace onnxruntime {
// Forward declarations: only the type names are introduced here, not their definitions.
struct FeedsFetchesInfo;
class ExecutionProviders;
Various optimizations to reduce the setup and device copying cost outside of the call to ExecuteGraph. (#470) * Various optimizations to reduce the setup and execution cost. Cache information about the feeds and fetches, and any device copies required to execute the graph so we minimize checking for later calls to ExecuteGraph using the same input/output. - enable use of caching in Loop and Scan - make use of caching optional for InferenceSession::Run - handle calls to Run with different feeds and fetches to support scenarios where there may be a truncated sequence in some calls Take the feed names and MLValue instances as vectors so the order is deterministic. Add unit tests Update onnxruntime_perf_test to enable caching. * Couple of tweaks. Fix shared library unit test failure. Attempt to workaround MacOS build failure due to VC++ bug around including reaching scope values in a lambda automatically. * Rework order of init in Run so we get nice error messages about invalid feed/output names. * Refine logic around copying MLValue using execution provider so common code can be used. Simplify the logic due to this change. Split the paths for executing with/without cached info so we can be more const correct with how FeedsFetchesManager is passed in. This makes it clearer when a shared instance can be used due to it being const. Cache the FeedsFetchesManager instances in the control flow nodes. They can be re-used across calls to Compute. * Removed unused local variable to fix some builds. * Fix build issue by cleaning up some more unused params. * Check names when using cache entry from SessionState. Add unit test.
2019-02-20 02:12:17 +00:00
// Forward declarations. FeedsFetchesManager is taken by reference in the utils function
// declarations of this header, so an incomplete type is sufficient for it here.
class Graph;
struct MLValueCopyInfo;
class FeedsFetchesManager;
2018-11-20 00:48:22 +00:00
// Forward declarations (alphabetical); only the type names are introduced here.
class IExecutionProvider;
class KernelDef;
class KernelRegistryManager;
class Node;
class Tensor;
2018-11-20 00:48:22 +00:00
// Forward declaration of logging::Logger, which the graph execution declarations in this header
// take by const reference.
namespace logging {
class Logger;
}  // namespace logging
namespace utils {
// Raw allocation helpers. Only the declarations are visible in this header.
// DefaultAlloc takes a size and returns an untyped pointer; DefaultFree takes an untyped pointer.
// NOTE(review): presumably 'size' is in bytes and DefaultFree must only be given pointers obtained
// from DefaultAlloc - confirm against the definitions.
void* DefaultAlloc(size_t size);
void DefaultFree(void* p);
2018-11-20 00:48:22 +00:00
// Returns a reference to a provider type string derived from 'info'.
// NOTE(review): presumably this is the type of the execution provider that consumes the node input
// described by 'info'; the lifetime of the returned reference is not established by this header -
// confirm both against the definition.
const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info);
// Handles a single input value, identified by 'input_name', for the given session state.
// 'orig_mlvalue' is read-only; 'new_mlvalue' is a non-const reference and is the output of the call.
// Returns a Status reporting the outcome.
// NOTE(review): presumably 'new_mlvalue' receives a copy of 'orig_mlvalue' located on the device
// where the input is consumed - confirm against the definition.
common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
const OrtValue& orig_mlvalue, OrtValue& new_mlvalue);
// Searches the allocation plan from the session_state to find the OrtMemoryInfo for the value 'name'.
// The result is returned by const reference.
// NOTE(review): the behavior when 'name' is not present in the allocation plan is not visible in
// this header - confirm against the definition before relying on it.
const OrtMemoryInfo& FindMemoryInfoForValue(const SessionState& session_state,
const std::string& name);
// Initialize the feed and fetch copy info using session_state.
// Determines the device that each graph input that will be fed will be consumed on,
// and the device that each graph output that will be fetched will be created on.
// 'feeds_fetches_manager' is passed by non-const reference and is updated by the call.
// Returns a Status reporting the outcome.
common::Status InitializeFeedFetchCopyInfo(const SessionState& session_state,
FeedsFetchesManager& feeds_fetches_manager);
// Finalize the feed and fetch copy info using session_state and the device and location information from the feeds
// and fetches that will be used in graph execution.
// 'feeds_fetches_manager' is passed by non-const reference and is updated by the call.
// 'feed_locations' holds one OrtDevice per entry; 'fetch_alloc_info' holds OrtMemoryInfo pointers.
// NOTE(review): presumably both vectors are ordered to match the feeds/fetches of the manager, and
// entries of 'fetch_alloc_info' may be null - confirm against the definition and callers.
void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
const std::vector<OrtDevice>& feed_locations,
const std::vector<const OrtMemoryInfo*>& fetch_alloc_info);
// Execute the main graph. The feed_fetches_manager will be finalized based on the provided feeds and fetches.
//   session_state                - session state of the graph to execute (read-only).
//   feeds_fetches_manager        - non-const: finalized by this call as described above.
//   feeds                        - input values (read-only).
//   fetches                      - output values; passed by non-const reference and filled by the call.
//   execution_mode               - execution mode to use for the run.
//   terminate_flag               - observed by reference. NOTE(review): presumably checked during
//                                  execution so a caller can request early termination - confirm.
//   logger                       - logger used for the run.
//   only_execute_path_to_fetches - defaults to false. NOTE(review): presumably restricts execution
//                                  to the nodes needed to produce the requested fetches - confirm.
// Returns a Status reporting the outcome.
common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
                            const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
                            ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger,
                            bool only_execute_path_to_fetches = false);
// Execute a subgraph. The feeds_fetches_manager should have been finalized prior to calling this function.
// See IControlFlowNode::SetupSubgraphExecutionInfo usage in the control flow kernels.
// 'feeds_fetches_manager' is const here, i.e. this call does not modify it.
// 'feeds' are the input values; 'fetches' is passed by non-const reference and receives the outputs.
// 'fetch_allocators' maps a size_t key to a custom allocator.
// NOTE(review): presumably the key is the index of the fetch the allocator applies to - confirm.
// 'terminate_flag' is observed by reference; returns a Status reporting the outcome.
common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFetchesManager& feeds_fetches_manager,
const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger);
// Compile-time mapping from a C++ element type T to its ONNXTensorElementDataType enumerator.
// The primary template is the fallback: any T without an explicit specialization below maps to
// ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED. Matching is exact, so only the listed types are mapped.
template <typename T>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
}

// ---- Floating point types ----
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<float>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<double>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; }
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<MLFloat16>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; }
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<BFloat16>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; }

// ---- Signed integer types ----
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<int8_t>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; }
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<int16_t>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; }
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<int32_t>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; }
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<int64_t>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; }

// ---- Unsigned integer types ----
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint8_t>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; }
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint16_t>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; }
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint32_t>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; }
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<uint64_t>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; }

// ---- Boolean and string types ----
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<bool>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; }
template <>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<std::string>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; }
// Converts an ONNXTensorElementDataType value (the unnamed parameter) to an int32_t.
// NOTE(review): presumably the result is the numeric value of the matching ONNX TensorProto data
// type; the result for unmapped inputs is not visible in this header - confirm against the definition.
int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType);
// Checks the inputs reachable through 'context' and reports the outcome as a Status.
// NOTE(review): going by the name, this presumably verifies that the kernel's input tensors sit
// back-to-back in one contiguous allocation - confirm against the definition.
// NOTE(review): OpKernelContext is not forward declared in the visible part of this header, so it
// must be provided by one of the included headers - confirm.
common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context);
2018-11-20 00:48:22 +00:00
} // namespace utils
} // namespace onnxruntime