amdmigraphx_ep-add ops to be supported by migraphx and fixed a bug in check ops to be supported (#10496)

* backup debugging information related to debugging a jira ticket

* fixed a bug in checking whether an input can be constand folded

* added more operators that are supported by migraphx

* revert unnecessary changes

* remove unused logger parameter

* rename function to make name style consistent

* backup code changes

* fix review comments

* refactor graph utility functions to add unit tests

* backup additional changes

* fixed a link error in build migraphx_basic_test

* add unit test for some migraphx utility functions

* add more supported ops in migraphx
This commit is contained in:
Shucai Xiao 2022-03-23 21:17:19 -05:00 committed by GitHub
parent ae08f9666d
commit 7ee52fb8a0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 440 additions and 190 deletions

View file

@ -528,6 +528,13 @@ if(onnxruntime_USE_TENSORRT)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_tensorrt onnxruntime_providers_shared)
endif()
if(onnxruntime_USE_MIGRAPHX)
list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/migraphx/*)
list(APPEND onnxruntime_test_framework_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/migraphx/migraphx_execution_provider_utils.h")
list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared)
endif()
if(onnxruntime_USE_NNAPI_BUILTIN)
list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/nnapi/*)
list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_nnapi)

View file

@ -5,7 +5,9 @@
#define ORT_API_MANUAL_INIT
#include "core/session/onnxruntime_cxx_api.h"
#include "core/common/safeint.h"
#include "core/common/logging/severity.h"
#include "migraphx_execution_provider.h"
#include "migraphx_execution_provider_utils.h"
#include "hip_allocator.h"
#include "hip_fence.h"
#include "gpu_data_transfer.h"
@ -112,6 +114,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
if (!fp16_enable_env.empty()) {
fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true);
}
// dump unsupported ops
const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps);
if (!dump_model_ops_env.empty()) {
dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true);
}
}
AllocatorPtr MIGraphXExecutionProvider::GetAllocator(int id, OrtMemType mem_type) const {
@ -196,7 +204,7 @@ static bool IsTypeSupported(const NodeArg* node_arg) {
}
}
static bool get_migraphx_type(ONNXTensorElementDataType type,
static bool getMIGraphXType(ONNXTensorElementDataType type,
migraphx_shape_datatype_t& mgx_type) {
mgx_type = migraphx_shape_float_type;
switch (type) {
@ -245,35 +253,8 @@ static bool get_migraphx_type(ONNXTensorElementDataType type,
return true;
}
static bool IsGraphInput(const GraphViewer& graph, const std::string& name)
{
const auto& graph_inputs = graph.GetInputs();
std::vector<std::string> input_names(graph_inputs.size());
std::transform(graph_inputs.begin(), graph_inputs.end(), input_names.begin(), [](auto in) {
return in->Name();
});
return (std::find(input_names.begin(), input_names.end(), name) != input_names.end());
}
static bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, bool check_outer_scope = true) {
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
return graph.GetInitializedTensor(name, initializer);
}
const Node* GetInputNode(const Node& node, int arg_index) {
int index = 0;
for (auto nit = node.InputNodesBegin(); nit != node.InputNodesEnd(); ++nit, ++index)
{
if (index == arg_index)
{
return &(*nit);
}
}
return nullptr;
}
std::vector<int> to_vector(const ONNX_NAMESPACE::int64s& nums)
std::vector<int> toVector(const ONNX_NAMESPACE::int64s& nums)
{
std::vector<int> result;
int num = nums.size();
@ -285,125 +266,7 @@ std::vector<int> to_vector(const ONNX_NAMESPACE::int64s& nums)
return result;
}
std::size_t node_input_num(const Node& node)
{
std::size_t node_num = 0;
for(auto it = node.InputNodesBegin(); it != node.InputNodesEnd(); ++it)
{
node_num++;
}
return node_num;
}
static bool can_eval_shape_general(const GraphViewer& graph, const Node* node, const logging::Logger& logger, std::vector<NodeIndex>& input_nodes)
{
if (node == nullptr)
{
return false;
}
std::vector<const Node*> in_nodes;
for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit)
{
in_nodes.push_back(&(*nit));
}
if (node->OpType() == "Shape")
{
input_nodes.push_back(node->Index());
return true;
}
auto inputs = node->InputDefs();
for (std::size_t i = 0; i < inputs.size(); ++i)
{
const std::string& input_name = inputs.at(i)->Name();
// If it is an initializer, it can be constant folded
if (IsGraphInitializer(graph, input_name))
{
continue;
}
// Input for sure cannot be constant folded
if (IsGraphInput(graph, input_name))
{
return false;
}
// find the node corresponding to the name
auto nit = std::find_if(in_nodes.begin(), in_nodes.end(), [&](auto n) {
return input_name.find(n->Name()) != std::string::npos;
});
if (nit == in_nodes.end())
{
return false;
}
auto input_node = (*nit);
// shape node, it is OK
if (input_node->OpType() == "Shape")
{
continue;
}
if (can_eval_shape_general(graph, input_node, logger, input_nodes))
{
continue;
}
return false;
}
input_nodes.push_back(node->Index());
return true;
}
static bool can_eval_node_argument(const GraphViewer& graph, const Node* node, std::vector<std::size_t> indices, const logging::Logger& logger, std::vector<NodeIndex>& input_nodes)
{
input_nodes.clear();
std::vector<const Node*> in_nodes;
for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit)
{
in_nodes.push_back(&(*nit));
}
auto inputs = node->InputDefs();
for (auto index : indices)
{
// an initializer itself is a constant
auto input_name = inputs.at(index)->Name();
if (IsGraphInitializer(graph, input_name))
{
continue;
}
// Input cannot be constant folded
if (IsGraphInput(graph, input_name))
{
return false;
}
// find the node corresponding to the name
auto nit = std::find_if(in_nodes.begin(), in_nodes.end(), [&](auto n) {
return input_name.find(n->Name()) != std::string::npos;
});
if (nit == in_nodes.end())
{
return false;
}
if (!can_eval_shape_general(graph, *nit, logger, input_nodes))
{
return false;
}
}
return true;
}
static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, const Node* node, const logging::Logger& logger) {
static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, const Node* node) {
std::vector<NodeIndex> input_nodes;
const auto& optype = node->OpType();
if (optype == "ArgMax" or optype == "ArgMin") {
@ -414,7 +277,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
return true;
}
} else if (optype == "ConstantOfShape") {
if (!can_eval_node_argument(graph_viewer, node, {0}, logger, input_nodes))
if (!canEvalNodeArgument(graph_viewer, node, {0}, input_nodes))
{
return true;
}
@ -439,7 +302,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
}
} else if (optype == "Expand") {
// MIGraphX only supports constant shape input values
if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
{
return true;
}
@ -454,7 +317,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
const auto& attributes = node->GetAttributes();
auto dila_attr = attributes.find("dilations");
if (dila_attr != attributes.end()) {
auto dilas = to_vector((*dila_attr).second.ints());
auto dilas = toVector((*dila_attr).second.ints());
bool ret = std::all_of(dilas.begin(), dilas.end(), [](auto i) { return i == 1; });
if (ret == false) {
return true;
@ -493,12 +356,12 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
return true;
}
} else if (optype == "NonZero") {
if (!can_eval_node_argument(graph_viewer, node, {0}, logger, input_nodes))
if (!canEvalNodeArgument(graph_viewer, node, {0}, input_nodes))
{
return true;
}
} else if (optype == "OneHot") {
if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
{
return true;
}
@ -506,7 +369,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
const auto& args = node->InputDefs();
// if pad size is not constant, migraphx cannot support
if (args.size() >= 2) {
if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
{
return true;
}
@ -527,7 +390,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
// input value only applied to constant mode
if (mode == "constant") {
if (args.size() == 3) {
if (!can_eval_node_argument(graph_viewer, node, {2}, logger, input_nodes))
if (!canEvalNodeArgument(graph_viewer, node, {2}, input_nodes))
{
return true;
}
@ -537,20 +400,20 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
auto arg_num = node->InputDefs().size();
std::vector<std::size_t> vec(arg_num);
std::iota(vec.begin(), vec.end(), 0);
if (!can_eval_node_argument(graph_viewer, node, vec, logger, input_nodes))
if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes))
{
return true;
}
} else if (optype == "Reshape") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
{
return false;
}
return true;
}
} else if (optype == "Resize") {
} else if (optype == "Resize" or optype == "Upsample") {
const auto& attributes = node->GetAttributes();
auto ct_attr = attributes.find("coordinate_transformation_mode");
if (ct_attr != attributes.end()) {
@ -575,7 +438,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
{
std::vector<std::size_t> indices(args.size() - 1);
std::iota(indices.begin(), indices.end(), 1);
if (can_eval_node_argument(graph_viewer, node, indices, logger, input_nodes))
if (canEvalNodeArgument(graph_viewer, node, indices, input_nodes))
{
return false;
}
@ -584,7 +447,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
} else if (optype == "ReduceSum") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
{
return false;
}
@ -598,15 +461,15 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
std::vector<std::size_t> vec(arg_num);
std::iota(vec.begin(), vec.end(), 0);
vec.erase(vec.begin());
if (!can_eval_node_argument(graph_viewer, node, vec, logger, input_nodes))
if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes))
{
return true;
}
const auto& attributes = node->GetAttributes();
if (attributes.count("starts") > 0 and attributes.count("ends") > 0) {
auto starts = to_vector((*attributes.find("starts")).second.ints());
auto ends = to_vector((*attributes.find("ends")).second.ints());
auto starts = toVector((*attributes.find("starts")).second.ints());
auto ends = toVector((*attributes.find("ends")).second.ints());
for (std::size_t i = 0; i < starts.size(); ++i) {
if (starts.at(i) > ends.at(i)) {
return true;
@ -636,26 +499,26 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
{
return false;
}
return true;
}
} else if (optype == "Tile") {
if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
{
return true;
}
} else if (optype == "TopK") {
if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
{
return true;
}
} else if (optype == "Unsqueeze" or optype == "Squeeze") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes))
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes))
{
return false;
}
@ -690,8 +553,7 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v
const auto& args = node->InputDefs();
if (args.size() == 2) {
std::vector<NodeIndex> node_inputs;
// if (can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, node_inputs))
if (can_eval_node_argument(graph_viewer, node, {1}, logger, node_inputs))
if (canEvalNodeArgument(graph_viewer, node, {1}, node_inputs))
{
return (not std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto index) {
return std::find(git.begin(), git.end(), index) != git.end();
@ -796,7 +658,7 @@ static bool IsNodeSupported(const std::set<std::string>& op_set,
}
// check that some modes might not be supported in migraphx for some operators
if (domain == kOnnxDomain && IsUnsupportedOpMode(graph_viewer, node, logger)) {
if (domain == kOnnxDomain && IsUnsupportedOpMode(graph_viewer, node)) {
// not supported, then check the constant folding capability of migraphx
// to see whether it is supported
return false;
@ -959,8 +821,6 @@ std::unique_ptr<IndexedSubGraph> MIGraphXExecutionProvider::GetSubGraph(const st
output_names.push_back(name);
}
// Generate unique kernel name for MIGraphX subgraph
uint64_t model_hash = 0;
int id = GenerateMetaDefId(graph, model_hash);
@ -992,21 +852,22 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer,
/*out*/ std::unordered_set<std::string>& mgx_required_initializers,
const logging::Logger& logger) {
static std::set<std::string> mgx_supported_ops = {"Abs", "Acos", "Acosh", "Add", "And",
"ArgMax", "ArgMin", "Asin", "Asinh", "Atan", "Atanh", "AveragePool",
"BatchNormalization", "Cast", "Ceil", "Clip", "Concat", "Constant", "ConstantFill",
"ConstantOfShape", "Conv", "Cos", "Cosh", "DepthToSpace", "DequantizeLinear", "Div",
"Dropout", "Elu", "Equal", "Erf", "Exp", "Expand", "Flatten", "Floor", "GRU", "Gather",
"GatherElements", "Gemm", "GlobalAveragePool", "GlobalMaxPool", "Greater", "Identity",
"If", "ImageScaler", "InstanceNormalization", "LRN", "LSTM", "LeakyRelu", "Less",
"LessOrEqual", "Log", "LogSoftmax", "Loop", "MatMul", "Max", "MaxPool", "Min", "Mul",
"Multinomial", "Neg", "NonZero", "Not", "NonMaxSuppression", "OneHot", "Or", "Pad", "Pow",
"PRelu", "QuantizeLinear", "RNN", "RandomNormal", "RandomNormalLike", "RandomUniform",
"RandomUniformLike", "Range", "Reciprocal", "ReduceL1", "ReduceL2", "ReduceLogSum",
"ReduceLogSumExp", "ReduceMax", "ReduceMean", "ReduceMin", "ReduceProd", "ReduceSum",
"ReduceSumSquare", "Relu", "Reshape", "Resize", "Roialign", "Round", "Scatter", "Selu",
"Shape", "Sigmoid", "Sign", "Sin", "Sinh", "Slice", "Softmax", "SpaceToDepth", "Split",
"Sqrt", "Squeeze", "Sub", "Sum", "Tan", "Tanh", "Tile", "TopK", "Transpose", "Unsqueeze",
"Where", "Xor"};
"ArgMax", "ArgMin", "Asin", "Asinh", "Atan", "Atanh", "ATen", "AveragePool",
"BatchNormalization", "Cast", "Ceil", "Celu", "Clip", "Concat", "Constant", "ConstantFill",
"ConstantOfShape", "Conv", "ConvInteger", "ConvTranspose", "Cos", "Cosh", "CumSum",
"DepthToSpace", "DequantizeLinear", "Div", "Dropout", "Elu", "Equal", "Erf", "Exp",
"Expand", "EyeLike", "Flatten", "Floor", "GRU", "Gather", "GatherElements", "Gemm", "GlobalAveragePool",
"GlobalMaxPool", "Greater", "GreaterOrEqual", "HardSigmoid", "HardSwish", "Identity",
"If", "ImageScaler", "InstanceNormalization", "LeakyRelu", "Less", "LessOrEqual",
"Log", "LogSoftmax", "Loop", "LpNormalization", "LRN", "LSTM", "MatMul", "MatMulInteger", "Max", "MaxPool",
"Mean", "Min", "Mul", "Multinomial", "Neg", "NonMaxSuppression", "NonZero", "Not",
"OneHot", "Or", "Pad", "Pow", "PRelu", "QuantizeLinear", "RandomNormal", "RandomNormalLike",
"RandomUniform", "RandomUniformLike", "Range", "Reciprocal", "ReduceL1", "ReduceL2",
"ReduceLogSum", "ReduceLogSumExp", "ReduceMax", "ReduceMean", "ReduceMin", "ReduceProd",
"ReduceSum", "ReduceSumSquare", "Relu", "Reshape", "Resize", "RNN", "Roialign", "Round",
"Scatter", "ScatterElements", "ScatterND", "Selu", "Shape", "Sigmoid", "Sign", "Sin", "Sinh", "Slice", "Softmax", "Softplus",
"Softsign", "SpaceToDepth", "Split", "Sqrt", "Squeeze", "Sub", "Sum", "Tan", "Tanh",
"ThresholdedRelu", "Tile", "TopK", "Transpose", "Unsqueeze", "Upsample", "Where", "Xor"};
std::vector<NodeIndex> unsupported_nodes_idx;
for (const auto& node_idx : graph_viewer.GetNodesInTopologicalOrder()) {
if (IsNodeSupported(mgx_supported_ops, graph_viewer, node_idx, logger)) {
@ -1061,10 +922,17 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v
auto model_proto = model->ToProto();
ToGraphProtoInternal(graph_viewer, *model_proto->mutable_graph());
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
std::string onnx_string_buffer;
model_proto->SerializeToString(onnx_string_buffer);
// dump onnx file if environment var is set
if (dump_model_ops_) {
std::string model_name = graph_viewer.Name() + ".onnx";
std::ofstream ofs(model_name);
ofs.write(onnx_string_buffer.c_str(), onnx_string_buffer.size());
ofs.close();
}
// This is a list of initializers that migraphx considers as constants.
// Example weights, reshape shape etc.
std::unordered_set<std::string> mgx_required_initializers;
@ -1076,6 +944,15 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v
auto sub_graph = GetSubGraph(node_indices, graph_viewer);
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
} else { // unsupported_nodes_idx.empty()
if (dump_model_ops_) {
LOGS_DEFAULT(INFO) << "============= Unsupported nodes ====================" << std::endl;
for (auto idx : unsupported_nodes)
{
LOGS_DEFAULT(INFO) << graph_viewer.GetNode(idx)->OpType() << std::endl;
}
LOGS_DEFAULT(INFO) << "************* Unsupported nodes ********************" << std::endl;
}
if (unsupported_nodes.size() > 10)
{
return result;
@ -1170,6 +1047,13 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<onnxruntime::Node*>&
std::string onnx_string_buffer;
model_proto->SerializeToString(onnx_string_buffer);
if (dump_model_ops_) {
std::string onnx_name = fused_node->Name() + ".onnx";
std::ofstream ofs(onnx_name);
ofs.write(onnx_string_buffer.data(), onnx_string_buffer.size());
ofs.close();
}
std::vector<std::string> input_names, output_names;
no_input_shape = no_input_shape or get_input_output_names(*graph_body_viewer, input_names, output_names);
@ -1202,7 +1086,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<onnxruntime::Node*>&
std::unique_ptr<MIGraphXFuncState> p = std::make_unique<MIGraphXFuncState>();
*p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name],
map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_,
map_no_input_shape_[context->node_name], fp16_enable_};
map_no_input_shape_[context->node_name], fp16_enable_, dump_model_ops_};
*state = p.release();
return 0;
};
@ -1296,7 +1180,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<onnxruntime::Node*>&
ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
migraphx_shape_datatype_t mgx_type;
get_migraphx_type(tensor_type, mgx_type);
getMIGraphXType(tensor_type, mgx_type);
auto mgx_s = param_shapes[name];
if (mgx_type != mgx_s.type()) {

View file

@ -16,6 +16,7 @@ namespace onnxruntime {
namespace migraphx_env_vars {
static const std::string kFP16Enable = "ORT_MIGRAPHX_FP16_ENABLE";
static const std::string dumpModelOps = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
};
// Information to construct kernel function state.
@ -31,6 +32,7 @@ struct MIGraphXFuncState {
OrtMutex* mgx_mu_ptr = nullptr;
bool no_input_shape = false;
bool fp16_enable = false;
bool dump_model_ops = false;
};
// Logical device representation.
@ -58,6 +60,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
private:
bool fp16_enable_ = false;
bool dump_model_ops_ = false;
int device_id_;
migraphx::target t_;
OrtMutex mgx_mu_;

View file

@ -0,0 +1,164 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License
#pragma once
#include "core/session/onnxruntime_cxx_api.h"
#include "core/framework/allocatormgr.h"
#include "core/framework/execution_provider.h"
namespace onnxruntime {
bool IsGraphInput(const GraphViewer& graph, const std::string& name)
{
const auto& graph_inputs = graph.GetInputs();
std::vector<std::string> input_names(graph_inputs.size());
std::transform(graph_inputs.begin(), graph_inputs.end(), input_names.begin(), [](auto in) {
return in->Name();
});
return (std::find(input_names.begin(), input_names.end(), name) != input_names.end());
}
bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, bool check_outer_scope = true) {
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
return graph.GetInitializedTensor(name, initializer);
}
const Node* GetInputNode(const Node& node, int arg_index) {
int index = 0;
for (auto nit = node.InputNodesBegin(); nit != node.InputNodesEnd(); ++nit, ++index)
{
if (index == arg_index)
{
return &(*nit);
}
}
return nullptr;
}
std::size_t getNodeInputNum(const Node& node)
{
std::size_t node_num = 0;
for(auto it = node.InputNodesBegin(); it != node.InputNodesEnd(); ++it)
{
node_num++;
}
return node_num;
}
bool isInputNode(const Node* node, const std::string& name)
{
auto outputs = node->OutputDefs();
return std::any_of(outputs.begin(), outputs.end(), [&](auto out) {
return (out->Name() == name);
});
}
bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector<NodeIndex>& input_nodes)
{
if (node == nullptr)
{
return false;
}
std::vector<const Node*> in_nodes;
for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit)
{
in_nodes.push_back(&(*nit));
}
if (node->OpType() == "Shape")
{
input_nodes.push_back(node->Index());
return true;
}
auto inputs = node->InputDefs();
for (std::size_t i = 0; i < inputs.size(); ++i)
{
const std::string& input_name = inputs.at(i)->Name();
// If it is an initializer, it can be constant folded
if (IsGraphInitializer(graph, input_name))
{
continue;
}
// Input for sure cannot be constant folded
if (IsGraphInput(graph, input_name))
{
return false;
}
// find the node corresponding to the name
auto nit = std::find_if(in_nodes.begin(), in_nodes.end(), [&](auto n) {
return isInputNode(n, input_name);
});
if (nit == in_nodes.end())
{
return false;
}
auto input_node = (*nit);
// shape node, it is OK
if (input_node->OpType() == "Shape")
{
continue;
}
if (canEvalShapeGeneral(graph, input_node, input_nodes))
{
continue;
}
return false;
}
input_nodes.push_back(node->Index());
return true;
}
bool canEvalNodeArgument(const GraphViewer& graph, const Node* node, std::vector<std::size_t> indices, std::vector<NodeIndex>& input_nodes)
{
input_nodes.clear();
std::vector<const Node*> in_nodes;
for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit)
{
in_nodes.push_back(&(*nit));
}
auto inputs = node->InputDefs();
for (auto index : indices)
{
// an initializer itself is a constant
auto input_name = inputs.at(index)->Name();
if (IsGraphInitializer(graph, input_name))
{
continue;
}
// Input cannot be constant folded
if (IsGraphInput(graph, input_name))
{
return false;
}
// find the node corresponding to the name
auto nit = std::find_if(in_nodes.begin(), in_nodes.end(), [&](auto n) {
return isInputNode(n, input_name);
});
if (nit == in_nodes.end())
{
return false;
}
if (!canEvalShapeGeneral(graph, *nit, input_nodes))
{
return false;
}
}
return true;
}
}

View file

@ -0,0 +1,192 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/graph/onnx_protobuf.h"
#include "core/session/inference_session.h"
#include "test/providers/provider_test_utils.h"
#include "test/framework/test_utils.h"
#include "gtest/gtest.h"
#include "test/util/include/default_providers.h"
#include "test/util/include/scoped_env_vars.h"
#include "core/providers/migraphx/migraphx_execution_provider_utils.h"
#include <string>
#include <thread>
using namespace std;
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::logging;
namespace onnxruntime {
namespace test {
template <typename T>
void VerifyOutputs(const std::vector<OrtValue>& fetches, const std::vector<int64_t>& expected_dims,
const std::vector<T>& expected_values) {
ASSERT_EQ(1, fetches.size());
auto& rtensor = fetches.front().Get<Tensor>();
TensorShape expected_shape(expected_dims);
ASSERT_EQ(expected_shape, rtensor.Shape());
const std::vector<T> found(rtensor.template Data<T>(), rtensor.template Data<T>() + expected_values.size());
ASSERT_EQ(expected_values, found);
}
/**
* Create a simple model with two inputs and one initializer.
* input: "X", "Y" and "Z"
* output: "M"
*
* "X" "Y"
* \ /
* "Ini" Add
* | \ /
* | Add
* | \
* \ Shape
* \ /
* Reshape
* |
* M
*/
void CreateBaseModel(onnxruntime::Model& model, std::vector<int> dims) {
auto& graph = model.MainGraph();
std::vector<onnxruntime::NodeArg*> inputs;
std::vector<onnxruntime::NodeArg*> outputs;
// FLOAT tensor
ONNX_NAMESPACE::TypeProto float_tensor;
float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
for (auto dim: dims) {
float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
}
// INT tensor
ONNX_NAMESPACE::TypeProto int64_tensor;
int64_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
int64_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
// constant
TensorProto value_tensor;
value_tensor.add_dims(1);
value_tensor.add_float_data(1.f);
value_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
value_tensor.set_name("Ini");
graph.AddInitializedTensor(value_tensor);
// Create node1 (Add)
auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor);
auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor);
inputs.push_back(&input_arg_1);
inputs.push_back(&input_arg_2);
auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor);
outputs.push_back(&output_arg);
graph.AddNode("node_1", "Add", "node 1.", inputs, outputs);
// Create node2 (Add)
auto& input_arg_3 = graph.GetOrCreateNodeArg("Ini", &float_tensor);
inputs.clear();
inputs.push_back(&output_arg);
inputs.push_back(&input_arg_3);
auto& output_arg_2 = graph.GetOrCreateNodeArg("M", &float_tensor);
outputs.clear();
outputs.push_back(&output_arg_2);
graph.AddNode("node_2", "Add", "node 2.", inputs, outputs);
// Create node3 (Shape)
inputs.clear();
outputs.clear();
inputs.push_back(&output_arg_2);
auto& output_arg_3 = graph.GetOrCreateNodeArg("S", &int64_tensor);
outputs.push_back(&output_arg_3);
graph.AddNode("node_3", "Shape", "node 3.", inputs, outputs);
// Create node4 (Reshape)
inputs.clear();
outputs.clear();
inputs.push_back(&input_arg_3);
inputs.push_back(&output_arg_3);
auto& output_arg_4 = graph.GetOrCreateNodeArg("R", &float_tensor);
outputs.push_back(&output_arg_4);
graph.AddNode("node_4", "Reshape", "node 4.", inputs, outputs);
auto status = graph.Resolve();
ASSERT_TRUE(status.IsOK());
}
TEST(MIGraphXExecutionProviderTest, GraphInputName) {
std::string graph_name = "migraphx_util_test";
onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger());
std::vector<int> dims = {1, 3, 2};
CreateBaseModel(model, dims);
auto& graph = model.MainGraph();
GraphViewer gv(graph);
ASSERT_EQ(IsGraphInput(gv, "X"), true);
}
TEST(MIGraphXExecutionProviderTest, GraphInitializer) {
std::string graph_name = "migraphx_util_test";
onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger());
std::vector<int> dims = {1, 3, 2};
CreateBaseModel(model, dims);
auto& graph = model.MainGraph();
GraphViewer gv(graph);
ASSERT_EQ(IsGraphInitializer(gv, "Ini"), true);
}
TEST(MIGraphXExecutionProviderTest, NodeInputNum) {
std::string graph_name = "migraphx_util_test";
onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger());
std::vector<int> dims = {1, 3, 2};
CreateBaseModel(model, dims);
auto& graph = model.MainGraph();
GraphViewer gv(graph);
// get the first add node
const auto& node0 = gv.GetNode(0);
const auto& node1 = gv.GetNode(1);
ASSERT_EQ(getNodeInputNum(*node0), 0);
ASSERT_EQ(getNodeInputNum(*node1), 1);
}
TEST(MIGraphXExecutionProviderTest, IsNodeInput) {
std::string graph_name = "migraphx_util_test";
onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger());
std::vector<int> dims = {1, 3, 2};
CreateBaseModel(model, dims);
auto& graph = model.MainGraph();
GraphViewer gv(graph);
// get the first add node
const auto& node2 = gv.GetNode(1);
ASSERT_EQ(isInputNode(node2, "M"), true);
}
TEST(MIGraphXExecutionProviderTest, canEvalArgument) {
std::string graph_name = "migraphx_util_test";
onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger());
std::vector<int> dims = {1, 3, 2};
CreateBaseModel(model, dims);
auto& graph = model.MainGraph();
GraphViewer gv(graph);
// get the first add node
const auto& node2 = gv.GetNode(3);
std::vector<NodeIndex> input_nodes;
ASSERT_EQ(canEvalNodeArgument(gv, node2, {1}, input_nodes), true);
}
} // namespace test
} // namespace onnxruntime