mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
amdmigraphx_ep-add ops to be supported by migraphx and fixed a bug in check ops to be supported (#10496)
* backup debugging information related to debugging a jira ticket * fixed a bug in checking whether an input can be constand folded * added more operators that are supported by migraphx * revert unnecessary changes * remove unused logger parameter * rename function to make name style consistent * backup code changes * fix review comments * refactor graph utility functions to add unit tests * backup additional changes * fixed a link error in build migraphx_basic_test * add unit test for some migraphx utility functions * add more supported ops in migraphx
This commit is contained in:
parent
ae08f9666d
commit
7ee52fb8a0
5 changed files with 440 additions and 190 deletions
|
|
@ -528,6 +528,13 @@ if(onnxruntime_USE_TENSORRT)
|
|||
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_tensorrt onnxruntime_providers_shared)
|
||||
endif()
|
||||
|
||||
if(onnxruntime_USE_MIGRAPHX)
|
||||
list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/migraphx/*)
|
||||
list(APPEND onnxruntime_test_framework_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/migraphx/migraphx_execution_provider_utils.h")
|
||||
list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx)
|
||||
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared)
|
||||
endif()
|
||||
|
||||
if(onnxruntime_USE_NNAPI_BUILTIN)
|
||||
list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/nnapi/*)
|
||||
list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_nnapi)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,9 @@
|
|||
#define ORT_API_MANUAL_INIT
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/common/logging/severity.h"
|
||||
#include "migraphx_execution_provider.h"
|
||||
#include "migraphx_execution_provider_utils.h"
|
||||
#include "hip_allocator.h"
|
||||
#include "hip_fence.h"
|
||||
#include "gpu_data_transfer.h"
|
||||
|
|
@ -112,6 +114,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
|
|||
if (!fp16_enable_env.empty()) {
|
||||
fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true);
|
||||
}
|
||||
|
||||
// dump unsupported ops
|
||||
const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps);
|
||||
if (!dump_model_ops_env.empty()) {
|
||||
dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true);
|
||||
}
|
||||
}
|
||||
|
||||
AllocatorPtr MIGraphXExecutionProvider::GetAllocator(int id, OrtMemType mem_type) const {
|
||||
|
|
@ -196,7 +204,7 @@ static bool IsTypeSupported(const NodeArg* node_arg) {
|
|||
}
|
||||
}
|
||||
|
||||
static bool get_migraphx_type(ONNXTensorElementDataType type,
|
||||
static bool getMIGraphXType(ONNXTensorElementDataType type,
|
||||
migraphx_shape_datatype_t& mgx_type) {
|
||||
mgx_type = migraphx_shape_float_type;
|
||||
switch (type) {
|
||||
|
|
@ -245,35 +253,8 @@ static bool get_migraphx_type(ONNXTensorElementDataType type,
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool IsGraphInput(const GraphViewer& graph, const std::string& name)
|
||||
{
|
||||
const auto& graph_inputs = graph.GetInputs();
|
||||
std::vector<std::string> input_names(graph_inputs.size());
|
||||
std::transform(graph_inputs.begin(), graph_inputs.end(), input_names.begin(), [](auto in) {
|
||||
return in->Name();
|
||||
});
|
||||
return (std::find(input_names.begin(), input_names.end(), name) != input_names.end());
|
||||
}
|
||||
|
||||
static bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, bool check_outer_scope = true) {
|
||||
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
|
||||
return graph.GetInitializedTensor(name, initializer);
|
||||
}
|
||||
|
||||
const Node* GetInputNode(const Node& node, int arg_index) {
|
||||
int index = 0;
|
||||
for (auto nit = node.InputNodesBegin(); nit != node.InputNodesEnd(); ++nit, ++index)
|
||||
{
|
||||
if (index == arg_index)
|
||||
{
|
||||
return &(*nit);
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<int> to_vector(const ONNX_NAMESPACE::int64s& nums)
|
||||
std::vector<int> toVector(const ONNX_NAMESPACE::int64s& nums)
|
||||
{
|
||||
std::vector<int> result;
|
||||
int num = nums.size();
|
||||
|
|
@ -285,125 +266,7 @@ std::vector<int> to_vector(const ONNX_NAMESPACE::int64s& nums)
|
|||
return result;
|
||||
}
|
||||
|
||||
std::size_t node_input_num(const Node& node)
|
||||
{
|
||||
std::size_t node_num = 0;
|
||||
for(auto it = node.InputNodesBegin(); it != node.InputNodesEnd(); ++it)
|
||||
{
|
||||
node_num++;
|
||||
}
|
||||
|
||||
return node_num;
|
||||
}
|
||||
|
||||
static bool can_eval_shape_general(const GraphViewer& graph, const Node* node, const logging::Logger& logger, std::vector<NodeIndex>& input_nodes)
|
||||
{
|
||||
if (node == nullptr)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<const Node*> in_nodes;
|
||||
for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit)
|
||||
{
|
||||
in_nodes.push_back(&(*nit));
|
||||
}
|
||||
|
||||
if (node->OpType() == "Shape")
|
||||
{
|
||||
input_nodes.push_back(node->Index());
|
||||
return true;
|
||||
}
|
||||
|
||||
auto inputs = node->InputDefs();
|
||||
for (std::size_t i = 0; i < inputs.size(); ++i)
|
||||
{
|
||||
const std::string& input_name = inputs.at(i)->Name();
|
||||
// If it is an initializer, it can be constant folded
|
||||
if (IsGraphInitializer(graph, input_name))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Input for sure cannot be constant folded
|
||||
if (IsGraphInput(graph, input_name))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// find the node corresponding to the name
|
||||
auto nit = std::find_if(in_nodes.begin(), in_nodes.end(), [&](auto n) {
|
||||
return input_name.find(n->Name()) != std::string::npos;
|
||||
});
|
||||
if (nit == in_nodes.end())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
auto input_node = (*nit);
|
||||
// shape node, it is OK
|
||||
if (input_node->OpType() == "Shape")
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (can_eval_shape_general(graph, input_node, logger, input_nodes))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
input_nodes.push_back(node->Index());
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool can_eval_node_argument(const GraphViewer& graph, const Node* node, std::vector<std::size_t> indices, const logging::Logger& logger, std::vector<NodeIndex>& input_nodes)
|
||||
{
|
||||
input_nodes.clear();
|
||||
|
||||
std::vector<const Node*> in_nodes;
|
||||
for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit)
|
||||
{
|
||||
in_nodes.push_back(&(*nit));
|
||||
}
|
||||
|
||||
auto inputs = node->InputDefs();
|
||||
for (auto index : indices)
|
||||
{
|
||||
// an initializer itself is a constant
|
||||
auto input_name = inputs.at(index)->Name();
|
||||
if (IsGraphInitializer(graph, input_name))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Input cannot be constant folded
|
||||
if (IsGraphInput(graph, input_name))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// find the node corresponding to the name
|
||||
auto nit = std::find_if(in_nodes.begin(), in_nodes.end(), [&](auto n) {
|
||||
return input_name.find(n->Name()) != std::string::npos;
|
||||
});
|
||||
if (nit == in_nodes.end())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!can_eval_shape_general(graph, *nit, logger, input_nodes))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, const Node* node, const logging::Logger& logger) {
|
||||
static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, const Node* node) {
|
||||
std::vector<NodeIndex> input_nodes;
|
||||
const auto& optype = node->OpType();
|
||||
if (optype == "ArgMax" or optype == "ArgMin") {
|
||||
|
|
@ -414,7 +277,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
|
|||
return true;
|
||||
}
|
||||
} else if (optype == "ConstantOfShape") {
|
||||
if (!can_eval_node_argument(graph_viewer, node, {0}, logger, input_nodes))
|
||||
if (!canEvalNodeArgument(graph_viewer, node, {0}, input_nodes))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
|
@ -439,7 +302,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
|
|||
}
|
||||
} else if (optype == "Expand") {
|
||||
// MIGraphX only supports constant shape input values
|
||||
if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
|
||||
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
|
@ -454,7 +317,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
|
|||
const auto& attributes = node->GetAttributes();
|
||||
auto dila_attr = attributes.find("dilations");
|
||||
if (dila_attr != attributes.end()) {
|
||||
auto dilas = to_vector((*dila_attr).second.ints());
|
||||
auto dilas = toVector((*dila_attr).second.ints());
|
||||
bool ret = std::all_of(dilas.begin(), dilas.end(), [](auto i) { return i == 1; });
|
||||
if (ret == false) {
|
||||
return true;
|
||||
|
|
@ -493,12 +356,12 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
|
|||
return true;
|
||||
}
|
||||
} else if (optype == "NonZero") {
|
||||
if (!can_eval_node_argument(graph_viewer, node, {0}, logger, input_nodes))
|
||||
if (!canEvalNodeArgument(graph_viewer, node, {0}, input_nodes))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
} else if (optype == "OneHot") {
|
||||
if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
|
||||
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
|
@ -506,7 +369,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
|
|||
const auto& args = node->InputDefs();
|
||||
// if pad size is not constant, migraphx cannot support
|
||||
if (args.size() >= 2) {
|
||||
if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
|
||||
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
|
@ -527,7 +390,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
|
|||
// input value only applied to constant mode
|
||||
if (mode == "constant") {
|
||||
if (args.size() == 3) {
|
||||
if (!can_eval_node_argument(graph_viewer, node, {2}, logger, input_nodes))
|
||||
if (!canEvalNodeArgument(graph_viewer, node, {2}, input_nodes))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
|
@ -537,20 +400,20 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
|
|||
auto arg_num = node->InputDefs().size();
|
||||
std::vector<std::size_t> vec(arg_num);
|
||||
std::iota(vec.begin(), vec.end(), 0);
|
||||
if (!can_eval_node_argument(graph_viewer, node, vec, logger, input_nodes))
|
||||
if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
} else if (optype == "Reshape") {
|
||||
const auto& args = node->InputDefs();
|
||||
if (args.size() == 2) {
|
||||
if (can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
|
||||
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} else if (optype == "Resize") {
|
||||
} else if (optype == "Resize" or optype == "Upsample") {
|
||||
const auto& attributes = node->GetAttributes();
|
||||
auto ct_attr = attributes.find("coordinate_transformation_mode");
|
||||
if (ct_attr != attributes.end()) {
|
||||
|
|
@ -575,7 +438,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
|
|||
{
|
||||
std::vector<std::size_t> indices(args.size() - 1);
|
||||
std::iota(indices.begin(), indices.end(), 1);
|
||||
if (can_eval_node_argument(graph_viewer, node, indices, logger, input_nodes))
|
||||
if (canEvalNodeArgument(graph_viewer, node, indices, input_nodes))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
|
@ -584,7 +447,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
|
|||
} else if (optype == "ReduceSum") {
|
||||
const auto& args = node->InputDefs();
|
||||
if (args.size() == 2) {
|
||||
if (can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
|
||||
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
|
@ -598,15 +461,15 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
|
|||
std::vector<std::size_t> vec(arg_num);
|
||||
std::iota(vec.begin(), vec.end(), 0);
|
||||
vec.erase(vec.begin());
|
||||
if (!can_eval_node_argument(graph_viewer, node, vec, logger, input_nodes))
|
||||
if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
const auto& attributes = node->GetAttributes();
|
||||
if (attributes.count("starts") > 0 and attributes.count("ends") > 0) {
|
||||
auto starts = to_vector((*attributes.find("starts")).second.ints());
|
||||
auto ends = to_vector((*attributes.find("ends")).second.ints());
|
||||
auto starts = toVector((*attributes.find("starts")).second.ints());
|
||||
auto ends = toVector((*attributes.find("ends")).second.ints());
|
||||
for (std::size_t i = 0; i < starts.size(); ++i) {
|
||||
if (starts.at(i) > ends.at(i)) {
|
||||
return true;
|
||||
|
|
@ -636,26 +499,26 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
|
|||
|
||||
const auto& args = node->InputDefs();
|
||||
if (args.size() == 2) {
|
||||
if (can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
|
||||
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} else if (optype == "Tile") {
|
||||
if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
|
||||
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
} else if (optype == "TopK") {
|
||||
if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
|
||||
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
} else if (optype == "Unsqueeze" or optype == "Squeeze") {
|
||||
const auto& args = node->InputDefs();
|
||||
if (args.size() == 2) {
|
||||
if (can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
|
||||
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
|
@ -690,8 +553,7 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v
|
|||
const auto& args = node->InputDefs();
|
||||
if (args.size() == 2) {
|
||||
std::vector<NodeIndex> node_inputs;
|
||||
// if (can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, node_inputs))
|
||||
if (can_eval_node_argument(graph_viewer, node, {1}, logger, node_inputs))
|
||||
if (canEvalNodeArgument(graph_viewer, node, {1}, node_inputs))
|
||||
{
|
||||
return (not std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto index) {
|
||||
return std::find(git.begin(), git.end(), index) != git.end();
|
||||
|
|
@ -796,7 +658,7 @@ static bool IsNodeSupported(const std::set<std::string>& op_set,
|
|||
}
|
||||
|
||||
// check that some modes might not be supported in migraphx for some operators
|
||||
if (domain == kOnnxDomain && IsUnsupportedOpMode(graph_viewer, node, logger)) {
|
||||
if (domain == kOnnxDomain && IsUnsupportedOpMode(graph_viewer, node)) {
|
||||
// not supported, then check the constant folding capability of migraphx
|
||||
// to see whether it is supported
|
||||
return false;
|
||||
|
|
@ -959,8 +821,6 @@ std::unique_ptr<IndexedSubGraph> MIGraphXExecutionProvider::GetSubGraph(const st
|
|||
output_names.push_back(name);
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Generate unique kernel name for MIGraphX subgraph
|
||||
uint64_t model_hash = 0;
|
||||
int id = GenerateMetaDefId(graph, model_hash);
|
||||
|
|
@ -992,21 +852,22 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer,
|
|||
/*out*/ std::unordered_set<std::string>& mgx_required_initializers,
|
||||
const logging::Logger& logger) {
|
||||
static std::set<std::string> mgx_supported_ops = {"Abs", "Acos", "Acosh", "Add", "And",
|
||||
"ArgMax", "ArgMin", "Asin", "Asinh", "Atan", "Atanh", "AveragePool",
|
||||
"BatchNormalization", "Cast", "Ceil", "Clip", "Concat", "Constant", "ConstantFill",
|
||||
"ConstantOfShape", "Conv", "Cos", "Cosh", "DepthToSpace", "DequantizeLinear", "Div",
|
||||
"Dropout", "Elu", "Equal", "Erf", "Exp", "Expand", "Flatten", "Floor", "GRU", "Gather",
|
||||
"GatherElements", "Gemm", "GlobalAveragePool", "GlobalMaxPool", "Greater", "Identity",
|
||||
"If", "ImageScaler", "InstanceNormalization", "LRN", "LSTM", "LeakyRelu", "Less",
|
||||
"LessOrEqual", "Log", "LogSoftmax", "Loop", "MatMul", "Max", "MaxPool", "Min", "Mul",
|
||||
"Multinomial", "Neg", "NonZero", "Not", "NonMaxSuppression", "OneHot", "Or", "Pad", "Pow",
|
||||
"PRelu", "QuantizeLinear", "RNN", "RandomNormal", "RandomNormalLike", "RandomUniform",
|
||||
"RandomUniformLike", "Range", "Reciprocal", "ReduceL1", "ReduceL2", "ReduceLogSum",
|
||||
"ReduceLogSumExp", "ReduceMax", "ReduceMean", "ReduceMin", "ReduceProd", "ReduceSum",
|
||||
"ReduceSumSquare", "Relu", "Reshape", "Resize", "Roialign", "Round", "Scatter", "Selu",
|
||||
"Shape", "Sigmoid", "Sign", "Sin", "Sinh", "Slice", "Softmax", "SpaceToDepth", "Split",
|
||||
"Sqrt", "Squeeze", "Sub", "Sum", "Tan", "Tanh", "Tile", "TopK", "Transpose", "Unsqueeze",
|
||||
"Where", "Xor"};
|
||||
"ArgMax", "ArgMin", "Asin", "Asinh", "Atan", "Atanh", "ATen", "AveragePool",
|
||||
"BatchNormalization", "Cast", "Ceil", "Celu", "Clip", "Concat", "Constant", "ConstantFill",
|
||||
"ConstantOfShape", "Conv", "ConvInteger", "ConvTranspose", "Cos", "Cosh", "CumSum",
|
||||
"DepthToSpace", "DequantizeLinear", "Div", "Dropout", "Elu", "Equal", "Erf", "Exp",
|
||||
"Expand", "EyeLike", "Flatten", "Floor", "GRU", "Gather", "GatherElements", "Gemm", "GlobalAveragePool",
|
||||
"GlobalMaxPool", "Greater", "GreaterOrEqual", "HardSigmoid", "HardSwish", "Identity",
|
||||
"If", "ImageScaler", "InstanceNormalization", "LeakyRelu", "Less", "LessOrEqual",
|
||||
"Log", "LogSoftmax", "Loop", "LpNormalization", "LRN", "LSTM", "MatMul", "MatMulInteger", "Max", "MaxPool",
|
||||
"Mean", "Min", "Mul", "Multinomial", "Neg", "NonMaxSuppression", "NonZero", "Not",
|
||||
"OneHot", "Or", "Pad", "Pow", "PRelu", "QuantizeLinear", "RandomNormal", "RandomNormalLike",
|
||||
"RandomUniform", "RandomUniformLike", "Range", "Reciprocal", "ReduceL1", "ReduceL2",
|
||||
"ReduceLogSum", "ReduceLogSumExp", "ReduceMax", "ReduceMean", "ReduceMin", "ReduceProd",
|
||||
"ReduceSum", "ReduceSumSquare", "Relu", "Reshape", "Resize", "RNN", "Roialign", "Round",
|
||||
"Scatter", "ScatterElements", "ScatterND", "Selu", "Shape", "Sigmoid", "Sign", "Sin", "Sinh", "Slice", "Softmax", "Softplus",
|
||||
"Softsign", "SpaceToDepth", "Split", "Sqrt", "Squeeze", "Sub", "Sum", "Tan", "Tanh",
|
||||
"ThresholdedRelu", "Tile", "TopK", "Transpose", "Unsqueeze", "Upsample", "Where", "Xor"};
|
||||
std::vector<NodeIndex> unsupported_nodes_idx;
|
||||
for (const auto& node_idx : graph_viewer.GetNodesInTopologicalOrder()) {
|
||||
if (IsNodeSupported(mgx_supported_ops, graph_viewer, node_idx, logger)) {
|
||||
|
|
@ -1061,10 +922,17 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v
|
|||
auto model_proto = model->ToProto();
|
||||
ToGraphProtoInternal(graph_viewer, *model_proto->mutable_graph());
|
||||
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
|
||||
std::string onnx_string_buffer;
|
||||
model_proto->SerializeToString(onnx_string_buffer);
|
||||
|
||||
// dump onnx file if environment var is set
|
||||
if (dump_model_ops_) {
|
||||
std::string model_name = graph_viewer.Name() + ".onnx";
|
||||
std::ofstream ofs(model_name);
|
||||
ofs.write(onnx_string_buffer.c_str(), onnx_string_buffer.size());
|
||||
ofs.close();
|
||||
}
|
||||
|
||||
// This is a list of initializers that migraphx considers as constants.
|
||||
// Example weights, reshape shape etc.
|
||||
std::unordered_set<std::string> mgx_required_initializers;
|
||||
|
|
@ -1076,6 +944,15 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v
|
|||
auto sub_graph = GetSubGraph(node_indices, graph_viewer);
|
||||
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
|
||||
} else { // unsupported_nodes_idx.empty()
|
||||
if (dump_model_ops_) {
|
||||
LOGS_DEFAULT(INFO) << "============= Unsupported nodes ====================" << std::endl;
|
||||
for (auto idx : unsupported_nodes)
|
||||
{
|
||||
LOGS_DEFAULT(INFO) << graph_viewer.GetNode(idx)->OpType() << std::endl;
|
||||
}
|
||||
LOGS_DEFAULT(INFO) << "************* Unsupported nodes ********************" << std::endl;
|
||||
}
|
||||
|
||||
if (unsupported_nodes.size() > 10)
|
||||
{
|
||||
return result;
|
||||
|
|
@ -1170,6 +1047,13 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<onnxruntime::Node*>&
|
|||
std::string onnx_string_buffer;
|
||||
model_proto->SerializeToString(onnx_string_buffer);
|
||||
|
||||
if (dump_model_ops_) {
|
||||
std::string onnx_name = fused_node->Name() + ".onnx";
|
||||
std::ofstream ofs(onnx_name);
|
||||
ofs.write(onnx_string_buffer.data(), onnx_string_buffer.size());
|
||||
ofs.close();
|
||||
}
|
||||
|
||||
std::vector<std::string> input_names, output_names;
|
||||
no_input_shape = no_input_shape or get_input_output_names(*graph_body_viewer, input_names, output_names);
|
||||
|
||||
|
|
@ -1202,7 +1086,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<onnxruntime::Node*>&
|
|||
std::unique_ptr<MIGraphXFuncState> p = std::make_unique<MIGraphXFuncState>();
|
||||
*p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name],
|
||||
map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_,
|
||||
map_no_input_shape_[context->node_name], fp16_enable_};
|
||||
map_no_input_shape_[context->node_name], fp16_enable_, dump_model_ops_};
|
||||
*state = p.release();
|
||||
return 0;
|
||||
};
|
||||
|
|
@ -1296,7 +1180,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<onnxruntime::Node*>&
|
|||
ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
|
||||
|
||||
migraphx_shape_datatype_t mgx_type;
|
||||
get_migraphx_type(tensor_type, mgx_type);
|
||||
getMIGraphXType(tensor_type, mgx_type);
|
||||
auto mgx_s = param_shapes[name];
|
||||
|
||||
if (mgx_type != mgx_s.type()) {
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ namespace onnxruntime {
|
|||
|
||||
namespace migraphx_env_vars {
|
||||
static const std::string kFP16Enable = "ORT_MIGRAPHX_FP16_ENABLE";
|
||||
static const std::string dumpModelOps = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
|
||||
};
|
||||
|
||||
// Information to construct kernel function state.
|
||||
|
|
@ -31,6 +32,7 @@ struct MIGraphXFuncState {
|
|||
OrtMutex* mgx_mu_ptr = nullptr;
|
||||
bool no_input_shape = false;
|
||||
bool fp16_enable = false;
|
||||
bool dump_model_ops = false;
|
||||
};
|
||||
|
||||
// Logical device representation.
|
||||
|
|
@ -58,6 +60,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
|
|||
|
||||
private:
|
||||
bool fp16_enable_ = false;
|
||||
bool dump_model_ops_ = false;
|
||||
int device_id_;
|
||||
migraphx::target t_;
|
||||
OrtMutex mgx_mu_;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,164 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License
|
||||
|
||||
#pragma once
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "core/framework/allocatormgr.h"
|
||||
#include "core/framework/execution_provider.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
bool IsGraphInput(const GraphViewer& graph, const std::string& name)
|
||||
{
|
||||
const auto& graph_inputs = graph.GetInputs();
|
||||
std::vector<std::string> input_names(graph_inputs.size());
|
||||
std::transform(graph_inputs.begin(), graph_inputs.end(), input_names.begin(), [](auto in) {
|
||||
return in->Name();
|
||||
});
|
||||
return (std::find(input_names.begin(), input_names.end(), name) != input_names.end());
|
||||
}
|
||||
|
||||
bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, bool check_outer_scope = true) {
|
||||
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
|
||||
return graph.GetInitializedTensor(name, initializer);
|
||||
}
|
||||
|
||||
const Node* GetInputNode(const Node& node, int arg_index) {
|
||||
int index = 0;
|
||||
for (auto nit = node.InputNodesBegin(); nit != node.InputNodesEnd(); ++nit, ++index)
|
||||
{
|
||||
if (index == arg_index)
|
||||
{
|
||||
return &(*nit);
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::size_t getNodeInputNum(const Node& node)
|
||||
{
|
||||
std::size_t node_num = 0;
|
||||
for(auto it = node.InputNodesBegin(); it != node.InputNodesEnd(); ++it)
|
||||
{
|
||||
node_num++;
|
||||
}
|
||||
|
||||
return node_num;
|
||||
}
|
||||
|
||||
bool isInputNode(const Node* node, const std::string& name)
|
||||
{
|
||||
auto outputs = node->OutputDefs();
|
||||
return std::any_of(outputs.begin(), outputs.end(), [&](auto out) {
|
||||
return (out->Name() == name);
|
||||
});
|
||||
}
|
||||
|
||||
bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector<NodeIndex>& input_nodes)
|
||||
{
|
||||
if (node == nullptr)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<const Node*> in_nodes;
|
||||
for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit)
|
||||
{
|
||||
in_nodes.push_back(&(*nit));
|
||||
}
|
||||
|
||||
if (node->OpType() == "Shape")
|
||||
{
|
||||
input_nodes.push_back(node->Index());
|
||||
return true;
|
||||
}
|
||||
|
||||
auto inputs = node->InputDefs();
|
||||
for (std::size_t i = 0; i < inputs.size(); ++i)
|
||||
{
|
||||
const std::string& input_name = inputs.at(i)->Name();
|
||||
// If it is an initializer, it can be constant folded
|
||||
if (IsGraphInitializer(graph, input_name))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Input for sure cannot be constant folded
|
||||
if (IsGraphInput(graph, input_name))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// find the node corresponding to the name
|
||||
auto nit = std::find_if(in_nodes.begin(), in_nodes.end(), [&](auto n) {
|
||||
return isInputNode(n, input_name);
|
||||
});
|
||||
if (nit == in_nodes.end())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
auto input_node = (*nit);
|
||||
// shape node, it is OK
|
||||
if (input_node->OpType() == "Shape")
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (canEvalShapeGeneral(graph, input_node, input_nodes))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
input_nodes.push_back(node->Index());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool canEvalNodeArgument(const GraphViewer& graph, const Node* node, std::vector<std::size_t> indices, std::vector<NodeIndex>& input_nodes)
|
||||
{
|
||||
input_nodes.clear();
|
||||
std::vector<const Node*> in_nodes;
|
||||
for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit)
|
||||
{
|
||||
in_nodes.push_back(&(*nit));
|
||||
}
|
||||
|
||||
auto inputs = node->InputDefs();
|
||||
for (auto index : indices)
|
||||
{
|
||||
// an initializer itself is a constant
|
||||
auto input_name = inputs.at(index)->Name();
|
||||
if (IsGraphInitializer(graph, input_name))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Input cannot be constant folded
|
||||
if (IsGraphInput(graph, input_name))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// find the node corresponding to the name
|
||||
auto nit = std::find_if(in_nodes.begin(), in_nodes.end(), [&](auto n) {
|
||||
return isInputNode(n, input_name);
|
||||
});
|
||||
if (nit == in_nodes.end())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!canEvalShapeGeneral(graph, *nit, input_nodes))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
192
onnxruntime/test/providers/migraphx/migraphx_basic_test.cc
Normal file
192
onnxruntime/test/providers/migraphx/migraphx_basic_test.cc
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "core/session/inference_session.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/framework/test_utils.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
#include "test/util/include/scoped_env_vars.h"
|
||||
#include "core/providers/migraphx/migraphx_execution_provider_utils.h"
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
using namespace std;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace ::onnxruntime::logging;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace test {
|
||||
|
||||
template <typename T>
|
||||
void VerifyOutputs(const std::vector<OrtValue>& fetches, const std::vector<int64_t>& expected_dims,
|
||||
const std::vector<T>& expected_values) {
|
||||
ASSERT_EQ(1, fetches.size());
|
||||
auto& rtensor = fetches.front().Get<Tensor>();
|
||||
TensorShape expected_shape(expected_dims);
|
||||
ASSERT_EQ(expected_shape, rtensor.Shape());
|
||||
const std::vector<T> found(rtensor.template Data<T>(), rtensor.template Data<T>() + expected_values.size());
|
||||
ASSERT_EQ(expected_values, found);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a simple model with two inputs and one initializer.
|
||||
* input: "X", "Y" and "Z"
|
||||
* output: "M"
|
||||
*
|
||||
* "X" "Y"
|
||||
* \ /
|
||||
* "Ini" Add
|
||||
* | \ /
|
||||
* | Add
|
||||
* | \
|
||||
* \ Shape
|
||||
* \ /
|
||||
* Reshape
|
||||
* |
|
||||
* M
|
||||
*/
|
||||
void CreateBaseModel(onnxruntime::Model& model, std::vector<int> dims) {
|
||||
auto& graph = model.MainGraph();
|
||||
std::vector<onnxruntime::NodeArg*> inputs;
|
||||
std::vector<onnxruntime::NodeArg*> outputs;
|
||||
|
||||
// FLOAT tensor
|
||||
ONNX_NAMESPACE::TypeProto float_tensor;
|
||||
float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
|
||||
for (auto dim: dims) {
|
||||
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
|
||||
}
|
||||
|
||||
// INT tensor
|
||||
ONNX_NAMESPACE::TypeProto int64_tensor;
|
||||
int64_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
int64_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
|
||||
|
||||
// constant
|
||||
TensorProto value_tensor;
|
||||
value_tensor.add_dims(1);
|
||||
value_tensor.add_float_data(1.f);
|
||||
value_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
value_tensor.set_name("Ini");
|
||||
graph.AddInitializedTensor(value_tensor);
|
||||
|
||||
// Create node1 (Add)
|
||||
auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor);
|
||||
auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor);
|
||||
inputs.push_back(&input_arg_1);
|
||||
inputs.push_back(&input_arg_2);
|
||||
auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor);
|
||||
outputs.push_back(&output_arg);
|
||||
graph.AddNode("node_1", "Add", "node 1.", inputs, outputs);
|
||||
|
||||
// Create node2 (Add)
|
||||
auto& input_arg_3 = graph.GetOrCreateNodeArg("Ini", &float_tensor);
|
||||
inputs.clear();
|
||||
inputs.push_back(&output_arg);
|
||||
inputs.push_back(&input_arg_3);
|
||||
auto& output_arg_2 = graph.GetOrCreateNodeArg("M", &float_tensor);
|
||||
outputs.clear();
|
||||
outputs.push_back(&output_arg_2);
|
||||
graph.AddNode("node_2", "Add", "node 2.", inputs, outputs);
|
||||
|
||||
// Create node3 (Shape)
|
||||
inputs.clear();
|
||||
outputs.clear();
|
||||
inputs.push_back(&output_arg_2);
|
||||
auto& output_arg_3 = graph.GetOrCreateNodeArg("S", &int64_tensor);
|
||||
outputs.push_back(&output_arg_3);
|
||||
graph.AddNode("node_3", "Shape", "node 3.", inputs, outputs);
|
||||
|
||||
// Create node4 (Reshape)
|
||||
inputs.clear();
|
||||
outputs.clear();
|
||||
inputs.push_back(&input_arg_3);
|
||||
inputs.push_back(&output_arg_3);
|
||||
auto& output_arg_4 = graph.GetOrCreateNodeArg("R", &float_tensor);
|
||||
outputs.push_back(&output_arg_4);
|
||||
graph.AddNode("node_4", "Reshape", "node 4.", inputs, outputs);
|
||||
|
||||
auto status = graph.Resolve();
|
||||
ASSERT_TRUE(status.IsOK());
|
||||
}
|
||||
|
||||
TEST(MIGraphXExecutionProviderTest, GraphInputName) {
|
||||
std::string graph_name = "migraphx_util_test";
|
||||
onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger());
|
||||
std::vector<int> dims = {1, 3, 2};
|
||||
|
||||
CreateBaseModel(model, dims);
|
||||
|
||||
auto& graph = model.MainGraph();
|
||||
GraphViewer gv(graph);
|
||||
|
||||
ASSERT_EQ(IsGraphInput(gv, "X"), true);
|
||||
}
|
||||
|
||||
TEST(MIGraphXExecutionProviderTest, GraphInitializer) {
|
||||
std::string graph_name = "migraphx_util_test";
|
||||
onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger());
|
||||
std::vector<int> dims = {1, 3, 2};
|
||||
|
||||
CreateBaseModel(model, dims);
|
||||
|
||||
auto& graph = model.MainGraph();
|
||||
GraphViewer gv(graph);
|
||||
|
||||
ASSERT_EQ(IsGraphInitializer(gv, "Ini"), true);
|
||||
}
|
||||
|
||||
TEST(MIGraphXExecutionProviderTest, NodeInputNum) {
|
||||
std::string graph_name = "migraphx_util_test";
|
||||
onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger());
|
||||
std::vector<int> dims = {1, 3, 2};
|
||||
|
||||
CreateBaseModel(model, dims);
|
||||
|
||||
auto& graph = model.MainGraph();
|
||||
GraphViewer gv(graph);
|
||||
|
||||
// get the first add node
|
||||
const auto& node0 = gv.GetNode(0);
|
||||
const auto& node1 = gv.GetNode(1);
|
||||
|
||||
ASSERT_EQ(getNodeInputNum(*node0), 0);
|
||||
ASSERT_EQ(getNodeInputNum(*node1), 1);
|
||||
}
|
||||
|
||||
TEST(MIGraphXExecutionProviderTest, IsNodeInput) {
|
||||
std::string graph_name = "migraphx_util_test";
|
||||
onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger());
|
||||
std::vector<int> dims = {1, 3, 2};
|
||||
|
||||
CreateBaseModel(model, dims);
|
||||
|
||||
auto& graph = model.MainGraph();
|
||||
GraphViewer gv(graph);
|
||||
|
||||
// get the first add node
|
||||
const auto& node2 = gv.GetNode(1);
|
||||
ASSERT_EQ(isInputNode(node2, "M"), true);
|
||||
}
|
||||
|
||||
TEST(MIGraphXExecutionProviderTest, canEvalArgument) {
|
||||
std::string graph_name = "migraphx_util_test";
|
||||
onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger());
|
||||
std::vector<int> dims = {1, 3, 2};
|
||||
|
||||
CreateBaseModel(model, dims);
|
||||
|
||||
auto& graph = model.MainGraph();
|
||||
GraphViewer gv(graph);
|
||||
|
||||
// get the first add node
|
||||
const auto& node2 = gv.GetNode(3);
|
||||
std::vector<NodeIndex> input_nodes;
|
||||
ASSERT_EQ(canEvalNodeArgument(gv, node2, {1}, input_nodes), true);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue