onnxruntime/onnxruntime/core/codegen/common/common.h
KeDengMS 0d204f3f06
Implementation of TVM codegen library (#888)
Description:

This change adds the common part of the TVM-based codegen library. It includes the following parts:
* Microsoft TVM Inventory (MTI): a set of TVM ops for neural networks, similar to TOPI
* Compiler pass for traversing ONNX graph and generate TVM ops
* Compiler pass for traversing generated graph and specify TVM schedule
* Compiler pass for handling weight layout
* Utils for debugging

Motivation and Context:

TVM is an open deep learning compiler stack for cpu, gpu and specialized accelerators. To leverage it in ONNX, we built an execution provider named Nuphar. Currently, Nuphar gets good performance on CPUs with AVX2 on quantized LSTM models.

This codegen library was part of Nuphar execution provider. It is split out for sharing with other execution providers, as we'd like to reuse TVM in more devices.
2019-07-03 10:32:59 -07:00

151 lines
6.6 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/framework/compute_capability.h"
#include "core/framework/tensor.h"
#include "core/graph/graph_nodes.h"
#include "core/graph/graph_viewer.h"
// ORT_ENFORCE_DEBUG forwards its arguments to ORT_ENFORCE in builds where
// NDEBUG is not defined. When NDEBUG is defined it expands to nothing, so its
// arguments are not evaluated at all in such builds.
#ifdef NDEBUG
#define ORT_ENFORCE_DEBUG(...)
#else
#define ORT_ENFORCE_DEBUG(...) ORT_ENFORCE(__VA_ARGS__)
#endif  // NDEBUG
// DYN_PROMOTE is a simplified llvm::dyn_cast, which does not need RTTI.
// DYN_PROMOTE is faster than dynamic_cast and also has smaller binary size.
// Please use DYN_PROMOTE in a critical path.
//
// DYN_PROMOTE(BASE) defines four overloads of Promote<ToType>() that accept a
// const BASE*, a BASE*, a std::unique_ptr<BASE> or a std::shared_ptr<BASE>.
// Each overload returns the pointee cast to ToType when ToType::IsType()
// accepts it, and nullptr otherwise.
//
// ToType must provide a static IsType(const BASE*); see DYN_PROMOTE_BASE and
// DYN_PROMOTE_DERIVED. The smart-pointer overloads cast the raw pointer
// obtained from get() and return it as a non-owning pointer: ownership stays
// with the smart pointer that was passed in.
#define DYN_PROMOTE(BASE)                                       \
  template <typename ToType>                                    \
  inline const ToType* Promote(const BASE* base) {              \
    if (ToType::IsType(base))                                   \
      return static_cast<const ToType*>(base);                  \
    return nullptr;                                             \
  }                                                             \
                                                                \
  template <typename ToType>                                    \
  inline ToType* Promote(BASE* base) {                          \
    if (ToType::IsType(base))                                   \
      return static_cast<ToType*>(base);                        \
    return nullptr;                                             \
  }                                                             \
                                                                \
  template <typename ToType>                                    \
  inline ToType* Promote(const std::unique_ptr<BASE>& base) {   \
    if (ToType::IsType(base.get()))                             \
      return static_cast<ToType*>(base.get());                  \
    return nullptr;                                             \
  }                                                             \
                                                                \
  template <typename ToType>                                    \
  inline ToType* Promote(const std::shared_ptr<BASE>& base) {   \
    if (ToType::IsType(base.get()))                             \
      return static_cast<ToType*>(base.get());                  \
    return nullptr;                                             \
  }
// DYN_PROMOTE_BASE is placed inside the base class so that the class can be
// used with DYN_PROMOTE.
//   BASE        - the base class itself
//   TYPE_ID     - the enum class used to tag the classes of the hierarchy
//   TYPE_ID_VAR - the name of the variable in the base class that holds the tag
// It adds a static IsType() that accepts every BASE pointer, and a TypeID()
// accessor that returns TYPE_ID_VAR.
#define DYN_PROMOTE_BASE(BASE, TYPE_ID, TYPE_ID_VAR) \
  static inline bool IsType(const BASE*) {           \
    return true;                                     \
  }                                                  \
                                                     \
  inline const TYPE_ID TypeID() const {              \
    return TYPE_ID_VAR;                              \
  }
// DYN_PROMOTE_DERIVED is placed inside a derived class so that DYN_PROMOTE can
// promote a BASE pointer to that class.
//   BASE          - the base class of the hierarchy
//   TYPE_ID       - the enum class used to tag the classes of the hierarchy
//   TYPE_ID_VALUE - the enumerator of TYPE_ID that identifies this derived class
// It adds a static IsType() that reports whether base->TypeID() equals
// TYPE_ID::TYPE_ID_VALUE. The pointer is dereferenced, so it must not be null;
// that precondition is checked through ORT_ENFORCE_DEBUG.
#define DYN_PROMOTE_DERIVED(BASE, TYPE_ID, TYPE_ID_VALUE) \
  static inline bool IsType(const BASE* base) {           \
    ORT_ENFORCE_DEBUG(nullptr != base);                   \
    const bool is_this_type =                             \
        base->TypeID() == TYPE_ID::TYPE_ID_VALUE;         \
    return is_this_type;                                  \
  }
// DYNAMIC_PROMOTE is a dynamic_cast, so it needs RTTI.
// DYNAMIC_PROMOTE is usually slower than DYN_PROMOTE.
// Please use DYNAMIC_PROMOTE in a non-critical path.
//
// DYNAMIC_PROMOTE(BASE) defines four overloads of Promote<ToType>() that accept
// a const BASE*, a BASE*, a std::unique_ptr<BASE> or a std::shared_ptr<BASE>.
// Each overload applies dynamic_cast to the raw pointer and passes the result
// through ORT_ENFORCE(nullptr != derived) before returning it, so a failed
// cast is reported by ORT_ENFORCE (defined outside this header) instead of
// being handed back to the caller. The smart-pointer overloads return a
// non-owning pointer.
#define DYNAMIC_PROMOTE(BASE)                                   \
  template <typename ToType>                                    \
  inline const ToType* Promote(const BASE* base) {              \
    const ToType* derived = dynamic_cast<const ToType*>(base);  \
    ORT_ENFORCE(nullptr != derived);                            \
    return derived;                                             \
  }                                                             \
                                                                \
  template <typename ToType>                                    \
  inline ToType* Promote(BASE* base) {                          \
    ToType* derived = dynamic_cast<ToType*>(base);              \
    ORT_ENFORCE(nullptr != derived);                            \
    return derived;                                             \
  }                                                             \
                                                                \
  template <typename ToType>                                    \
  inline ToType* Promote(const std::unique_ptr<BASE>& base) {   \
    ToType* derived = dynamic_cast<ToType*>(base.get());        \
    ORT_ENFORCE(nullptr != derived);                            \
    return derived;                                             \
  }                                                             \
                                                                \
  template <typename ToType>                                    \
  inline ToType* Promote(const std::shared_ptr<BASE>& base) {   \
    ToType* derived = dynamic_cast<ToType*>(base.get());        \
    ORT_ENFORCE(nullptr != derived);                            \
    return derived;                                             \
  }
// Declarations of graph/node helper functions used by the codegen library.
// Only the declarations are visible here; the definitions live elsewhere, so
// the notes below that go beyond the signatures are marked as assumptions.
namespace onnxruntime {
// NodeKey is used as a key for maps
using NodeKey = std::string;
// GetKey overloads: produce the NodeKey for a Node (passed by pointer or by
// reference) or for a NodeArg.
NodeKey GetKey(const onnxruntime::Node* node);
NodeKey GetKey(const onnxruntime::Node& node);
NodeKey GetKey(const onnxruntime::NodeArg* def);
// Predicates on a node.
// NOTE(review): the exact criteria for "recurrent" and "alias" nodes are in the
// definitions, not visible in this header - confirm there.
bool IsRecurrentNode(const onnxruntime::Node& node);
bool IsAliasNode(const onnxruntime::Node& node);
// Helper function that creates ComputeCapability for subgraphs
// NOTE(review): subgraph is taken by non-const reference to a unique_ptr;
// presumably ownership is moved into the returned ComputeCapability - confirm.
std::unique_ptr<ComputeCapability> ToCapacity(const onnxruntime::GraphViewer& graph,
std::unique_ptr<IndexedSubGraph>& subgraph);
// More node predicates / accessors.
// NOTE(review): presumably GetSubgraph returns nullptr when the node has no
// subgraph - confirm against the definition.
bool IsFusedNode(const Node& node);
bool HasLoop(const Node& node);
const Graph* GetSubgraph(const Node& node);
// Name normalization helpers returning a new string.
// NOTE(review): presumably these produce names usable as C++ identifiers -
// confirm the normalization rules in the definition.
std::string NormalizeCppName(const std::string& name);
std::string NormalizeNodeArgName(const NodeArg* def);
// Return the corresponding input node for the NodeArg of the given node
const onnxruntime::Node* GetInputNode(const Node& node, const NodeArg* def);
// Shape queries on a NodeArg; i selects a dimension.
// NOTE(review): presumably ShapeValue / ShapeSymbol are only valid when
// ShapeHasValue / ShapeHasSymbol returns true for the same i - confirm.
int64_t ShapeRank(const NodeArg* def);
bool ShapeHasValue(const NodeArg* def, int i);
bool ShapeHasSymbol(const NodeArg* def, int i);
int64_t ShapeValue(const NodeArg* def, int i);
const std::string& ShapeSymbol(const NodeArg* def, int i);
// Returns the ONNX TensorProto data type associated with a NodeArg.
ONNX_NAMESPACE::TensorProto_DataType TensorProtoDataType(const NodeArg* def);
// Convert GraphNodes to internal Node pointers without checking lifetime.
// Please use it only locally while the GraphNodes still exist.
std::vector<const Node*> ConvertGraphNodesToNodePtrs(const GraphNodes& graph_nodes);
// NOTE(review): presumably the sentinel used for a dimension whose value is
// not known - confirm at the use sites.
enum : int {
Dimension_Unknown = -1,
};
} // namespace onnxruntime