diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 3de416f980..4fdb54c32d 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -528,6 +528,13 @@ if(onnxruntime_USE_TENSORRT) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_tensorrt onnxruntime_providers_shared) endif() +if(onnxruntime_USE_MIGRAPHX) + list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/migraphx/*) + list(APPEND onnxruntime_test_framework_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/migraphx/migraphx_execution_provider_utils.h") + list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared) +endif() + if(onnxruntime_USE_NNAPI_BUILTIN) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/nnapi/*) list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_nnapi) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index b5cd16db86..116cffcaa0 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -5,7 +5,9 @@ #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/common/safeint.h" +#include "core/common/logging/severity.h" #include "migraphx_execution_provider.h" +#include "migraphx_execution_provider_utils.h" #include "hip_allocator.h" #include "hip_fence.h" #include "gpu_data_transfer.h" @@ -112,6 +114,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv if (!fp16_enable_env.empty()) { fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); } + + // dump unsupported ops + const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps); + if (!dump_model_ops_env.empty()) { + dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true); + } } AllocatorPtr MIGraphXExecutionProvider::GetAllocator(int id, OrtMemType mem_type) const { @@ -196,7 +204,7 @@ static bool IsTypeSupported(const NodeArg* node_arg) { } } -static bool get_migraphx_type(ONNXTensorElementDataType type, +static bool getMIGraphXType(ONNXTensorElementDataType type, migraphx_shape_datatype_t& mgx_type) { mgx_type = migraphx_shape_float_type; switch (type) { @@ -245,35 +253,8 @@ static bool get_migraphx_type(ONNXTensorElementDataType type, return true; } -static bool IsGraphInput(const GraphViewer& graph, const std::string& name) -{ - const auto& graph_inputs = graph.GetInputs(); - std::vector input_names(graph_inputs.size()); - std::transform(graph_inputs.begin(), graph_inputs.end(), input_names.begin(), [](auto in) { - return in->Name(); - }); - return (std::find(input_names.begin(), input_names.end(), name) != input_names.end()); -} -static bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, bool check_outer_scope = true) { - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - return graph.GetInitializedTensor(name, initializer); -} - -const Node* GetInputNode(const Node& node, int arg_index) { - int index = 0; - for (auto nit = node.InputNodesBegin(); nit != node.InputNodesEnd(); ++nit, ++index) - { - if (index == arg_index) - { - return &(*nit); - } - } - - return nullptr; -} - -std::vector to_vector(const ONNX_NAMESPACE::int64s& nums) +std::vector toVector(const ONNX_NAMESPACE::int64s& nums) { std::vector result; int num = nums.size(); @@ -285,125 +266,7 @@ std::vector to_vector(const ONNX_NAMESPACE::int64s& nums) return result; } -std::size_t node_input_num(const Node& node) -{ - std::size_t node_num = 0; - for(auto it = node.InputNodesBegin(); it != node.InputNodesEnd(); ++it) - { - node_num++; - } - - return node_num; -} - -static bool can_eval_shape_general(const GraphViewer& graph, const Node* node, const logging::Logger& logger, std::vector& input_nodes) -{ - if (node == nullptr) - { - return false; - } - - std::vector in_nodes; - for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit) - { - in_nodes.push_back(&(*nit)); - } - - if (node->OpType() == "Shape") - { - input_nodes.push_back(node->Index()); - return true; - } - - auto inputs = node->InputDefs(); - for (std::size_t i = 0; i < inputs.size(); ++i) - { - const std::string& input_name = inputs.at(i)->Name(); - // If it is an initializer, it can be constant folded - if (IsGraphInitializer(graph, input_name)) - { - continue; - } - - // Input for sure cannot be constant folded - if (IsGraphInput(graph, input_name)) - { - return false; - } - - // find the node corresponding to the name - auto nit = std::find_if(in_nodes.begin(), in_nodes.end(), [&](auto n) { - return input_name.find(n->Name()) != std::string::npos; - }); - if (nit == in_nodes.end()) - { - return false; - } - - auto input_node = (*nit); - // shape node, it is OK - if (input_node->OpType() == "Shape") - { - continue; - } - - if (can_eval_shape_general(graph, input_node, logger, input_nodes)) - { - continue; - } - - return false; - } - - input_nodes.push_back(node->Index()); - return true; -} - -static bool can_eval_node_argument(const GraphViewer& graph, const Node* node, std::vector indices, const logging::Logger& logger, std::vector& input_nodes) -{ - input_nodes.clear(); - - std::vector in_nodes; - for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit) - { - in_nodes.push_back(&(*nit)); - } - - auto inputs = node->InputDefs(); - for (auto index : indices) - { - // an initializer itself is a constant - auto input_name = inputs.at(index)->Name(); - if (IsGraphInitializer(graph, input_name)) - { - continue; - } - - // Input cannot be constant folded - if (IsGraphInput(graph, input_name)) - { - return false; - } - - // find the node corresponding to the name - auto nit = std::find_if(in_nodes.begin(), in_nodes.end(), [&](auto n) { - return input_name.find(n->Name()) != std::string::npos; - }); - if (nit == in_nodes.end()) - { - return false; - } - - if (!can_eval_shape_general(graph, *nit, logger, input_nodes)) - { - return false; - } - } - - return true; -} - -static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, const Node* node, const logging::Logger& logger) { +static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, const Node* node) { std::vector input_nodes; const auto& optype = node->OpType(); if (optype == "ArgMax" or optype == "ArgMin") { @@ -414,7 +277,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } } else if (optype == "ConstantOfShape") { - if (!can_eval_node_argument(graph_viewer, node, {0}, logger, input_nodes)) + if (!canEvalNodeArgument(graph_viewer, node, {0}, input_nodes)) { return true; } @@ -439,7 +302,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } } else if (optype == "Expand") { // MIGraphX only supports constant shape input values - if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) + if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { return true; } @@ -454,7 +317,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co const auto& attributes = node->GetAttributes(); auto dila_attr = attributes.find("dilations"); if (dila_attr != attributes.end()) { - auto dilas = to_vector((*dila_attr).second.ints()); + auto dilas = toVector((*dila_attr).second.ints()); bool ret = std::all_of(dilas.begin(), dilas.end(), [](auto i) { return i == 1; }); if (ret == false) { return true; @@ -493,12 +356,12 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } } else if (optype == "NonZero") { - if (!can_eval_node_argument(graph_viewer, node, {0}, logger, input_nodes)) + if (!canEvalNodeArgument(graph_viewer, node, {0}, input_nodes)) { return true; } } else if (optype == "OneHot") { - if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) + if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { return true; } @@ -506,7 +369,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co const auto& args = node->InputDefs(); // if pad size is not constant, migraphx cannot support if (args.size() >= 2) { - if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) + if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { return true; } @@ -527,7 +390,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co // input value only applied to constant mode if (mode == "constant") { if (args.size() == 3) { - if (!can_eval_node_argument(graph_viewer, node, {2}, logger, input_nodes)) + if (!canEvalNodeArgument(graph_viewer, node, {2}, input_nodes)) { return true; } @@ -537,20 +400,20 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co auto arg_num = node->InputDefs().size(); std::vector vec(arg_num); std::iota(vec.begin(), vec.end(), 0); - if (!can_eval_node_argument(graph_viewer, node, vec, logger, input_nodes)) + if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) { return true; } } else if (optype == "Reshape") { const auto& args = node->InputDefs(); if (args.size() == 2) { - if (can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) + if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { return false; } return true; } - } else if (optype == "Resize") { + } else if (optype == "Resize" or optype == "Upsample") { const auto& attributes = node->GetAttributes(); auto ct_attr = attributes.find("coordinate_transformation_mode"); if (ct_attr != attributes.end()) { @@ -575,7 +438,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co { std::vector indices(args.size() - 1); std::iota(indices.begin(), indices.end(), 1); - if (can_eval_node_argument(graph_viewer, node, indices, logger, input_nodes)) + if (canEvalNodeArgument(graph_viewer, node, indices, input_nodes)) { return false; } @@ -584,7 +447,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } else if (optype == "ReduceSum") { const auto& args = node->InputDefs(); if (args.size() == 2) { - if (can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) + if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { return false; } @@ -598,15 +461,15 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co std::vector vec(arg_num); std::iota(vec.begin(), vec.end(), 0); vec.erase(vec.begin()); - if (!can_eval_node_argument(graph_viewer, node, vec, logger, input_nodes)) + if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) { return true; } const auto& attributes = node->GetAttributes(); if (attributes.count("starts") > 0 and attributes.count("ends") > 0) { - auto starts = to_vector((*attributes.find("starts")).second.ints()); - auto ends = to_vector((*attributes.find("ends")).second.ints()); + auto starts = toVector((*attributes.find("starts")).second.ints()); + auto ends = toVector((*attributes.find("ends")).second.ints()); for (std::size_t i = 0; i < starts.size(); ++i) { if (starts.at(i) > ends.at(i)) { return true; @@ -636,26 +499,26 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co const auto& args = node->InputDefs(); if (args.size() == 2) { - if (can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) + if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { return false; } return true; } } else if (optype == "Tile") { - if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) + if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { return true; } } else if (optype == "TopK") { - if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) + if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { return true; } } else if (optype == "Unsqueeze" or optype == "Squeeze") { const auto& args = node->InputDefs(); if (args.size() == 2) { - if (can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) + if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { return false; } @@ -690,8 +553,7 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v const auto& args = node->InputDefs(); if (args.size() == 2) { std::vector node_inputs; - // if (can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, node_inputs)) - if (can_eval_node_argument(graph_viewer, node, {1}, logger, node_inputs)) + if (canEvalNodeArgument(graph_viewer, node, {1}, node_inputs)) { return (not std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto index) { return std::find(git.begin(), git.end(), index) != git.end(); @@ -796,7 +658,7 @@ static bool IsNodeSupported(const std::set& op_set, } // check that some modes might not be supported in migraphx for some operators - if (domain == kOnnxDomain && IsUnsupportedOpMode(graph_viewer, node, logger)) { + if (domain == kOnnxDomain && IsUnsupportedOpMode(graph_viewer, node)) { // not supported, then check the constant folding capability of migraphx // to see whether it is supported return false; @@ -959,8 +821,6 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st output_names.push_back(name); } - - // Generate unique kernel name for MIGraphX subgraph uint64_t model_hash = 0; int id = GenerateMetaDefId(graph, model_hash); @@ -992,21 +852,22 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, /*out*/ std::unordered_set& mgx_required_initializers, const logging::Logger& logger) { static std::set mgx_supported_ops = {"Abs", "Acos", "Acosh", "Add", "And", - "ArgMax", "ArgMin", "Asin", "Asinh", "Atan", "Atanh", "AveragePool", - "BatchNormalization", "Cast", "Ceil", "Clip", "Concat", "Constant", "ConstantFill", - "ConstantOfShape", "Conv", "Cos", "Cosh", "DepthToSpace", "DequantizeLinear", "Div", - "Dropout", "Elu", "Equal", "Erf", "Exp", "Expand", "Flatten", "Floor", "GRU", "Gather", - "GatherElements", "Gemm", "GlobalAveragePool", "GlobalMaxPool", "Greater", "Identity", - "If", "ImageScaler", "InstanceNormalization", "LRN", "LSTM", "LeakyRelu", "Less", - "LessOrEqual", "Log", "LogSoftmax", "Loop", "MatMul", "Max", "MaxPool", "Min", "Mul", - "Multinomial", "Neg", "NonZero", "Not", "NonMaxSuppression", "OneHot", "Or", "Pad", "Pow", - "PRelu", "QuantizeLinear", "RNN", "RandomNormal", "RandomNormalLike", "RandomUniform", - "RandomUniformLike", "Range", "Reciprocal", "ReduceL1", "ReduceL2", "ReduceLogSum", - "ReduceLogSumExp", "ReduceMax", "ReduceMean", "ReduceMin", "ReduceProd", "ReduceSum", - "ReduceSumSquare", "Relu", "Reshape", "Resize", "Roialign", "Round", "Scatter", "Selu", - "Shape", "Sigmoid", "Sign", "Sin", "Sinh", "Slice", "Softmax", "SpaceToDepth", "Split", - "Sqrt", "Squeeze", "Sub", "Sum", "Tan", "Tanh", "Tile", "TopK", "Transpose", "Unsqueeze", - "Where", "Xor"}; + "ArgMax", "ArgMin", "Asin", "Asinh", "Atan", "Atanh", "ATen", "AveragePool", + "BatchNormalization", "Cast", "Ceil", "Celu", "Clip", "Concat", "Constant", "ConstantFill", + "ConstantOfShape", "Conv", "ConvInteger", "ConvTranspose", "Cos", "Cosh", "CumSum", + "DepthToSpace", "DequantizeLinear", "Div", "Dropout", "Elu", "Equal", "Erf", "Exp", + "Expand", "EyeLike", "Flatten", "Floor", "GRU", "Gather", "GatherElements", "Gemm", "GlobalAveragePool", + "GlobalMaxPool", "Greater", "GreaterOrEqual", "HardSigmoid", "HardSwish", "Identity", + "If", "ImageScaler", "InstanceNormalization", "LeakyRelu", "Less", "LessOrEqual", + "Log", "LogSoftmax", "Loop", "LpNormalization", "LRN", "LSTM", "MatMul", "MatMulInteger", "Max", "MaxPool", + "Mean", "Min", "Mul", "Multinomial", "Neg", "NonMaxSuppression", "NonZero", "Not", + "OneHot", "Or", "Pad", "Pow", "PRelu", "QuantizeLinear", "RandomNormal", "RandomNormalLike", + "RandomUniform", "RandomUniformLike", "Range", "Reciprocal", "ReduceL1", "ReduceL2", + "ReduceLogSum", "ReduceLogSumExp", "ReduceMax", "ReduceMean", "ReduceMin", "ReduceProd", + "ReduceSum", "ReduceSumSquare", "Relu", "Reshape", "Resize", "RNN", "Roialign", "Round", + "Scatter", "ScatterElements", "ScatterND", "Selu", "Shape", "Sigmoid", "Sign", "Sin", "Sinh", "Slice", "Softmax", "Softplus", + "Softsign", "SpaceToDepth", "Split", "Sqrt", "Squeeze", "Sub", "Sum", "Tan", "Tanh", + "ThresholdedRelu", "Tile", "TopK", "Transpose", "Unsqueeze", "Upsample", "Where", "Xor"}; std::vector unsupported_nodes_idx; for (const auto& node_idx : graph_viewer.GetNodesInTopologicalOrder()) { if (IsNodeSupported(mgx_supported_ops, graph_viewer, node_idx, logger)) { @@ -1061,10 +922,17 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v auto model_proto = model->ToProto(); ToGraphProtoInternal(graph_viewer, *model_proto->mutable_graph()); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - std::string onnx_string_buffer; model_proto->SerializeToString(onnx_string_buffer); + // dump onnx file if environment var is set + if (dump_model_ops_) { + std::string model_name = graph_viewer.Name() + ".onnx"; + std::ofstream ofs(model_name); + ofs.write(onnx_string_buffer.c_str(), onnx_string_buffer.size()); + ofs.close(); + } + // This is a list of initializers that migraphx considers as constants. // Example weights, reshape shape etc. std::unordered_set mgx_required_initializers; @@ -1076,6 +944,15 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v auto sub_graph = GetSubGraph(node_indices, graph_viewer); result.push_back(ComputeCapability::Create(std::move(sub_graph))); } else { // unsupported_nodes_idx.empty() + if (dump_model_ops_) { + LOGS_DEFAULT(INFO) << "============= Unsupported nodes ====================" << std::endl; + for (auto idx : unsupported_nodes) + { + LOGS_DEFAULT(INFO) << graph_viewer.GetNode(idx)->OpType() << std::endl; + } + LOGS_DEFAULT(INFO) << "************* Unsupported nodes ********************" << std::endl; + } + if (unsupported_nodes.size() > 10) { return result; @@ -1170,6 +1047,13 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::string onnx_string_buffer; model_proto->SerializeToString(onnx_string_buffer); + if (dump_model_ops_) { + std::string onnx_name = fused_node->Name() + ".onnx"; + std::ofstream ofs(onnx_name); + ofs.write(onnx_string_buffer.data(), onnx_string_buffer.size()); + ofs.close(); + } + std::vector input_names, output_names; no_input_shape = no_input_shape or get_input_output_names(*graph_body_viewer, input_names, output_names); @@ -1202,7 +1086,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_no_input_shape_[context->node_name], fp16_enable_}; + map_no_input_shape_[context->node_name], fp16_enable_, dump_model_ops_}; *state = p.release(); return 0; }; @@ -1296,7 +1180,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& ort.ReleaseTensorTypeAndShapeInfo(tensor_info); migraphx_shape_datatype_t mgx_type; - get_migraphx_type(tensor_type, mgx_type); + getMIGraphXType(tensor_type, mgx_type); auto mgx_s = param_shapes[name]; if (mgx_type != mgx_s.type()) { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 1bf222c193..9c5ac11fe3 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -16,6 +16,7 @@ namespace onnxruntime { namespace migraphx_env_vars { static const std::string kFP16Enable = "ORT_MIGRAPHX_FP16_ENABLE"; +static const std::string dumpModelOps = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; }; // Information to construct kernel function state. @@ -31,6 +32,7 @@ struct MIGraphXFuncState { OrtMutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; + bool dump_model_ops = false; }; // Logical device representation. @@ -58,6 +60,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { private: bool fp16_enable_ = false; + bool dump_model_ops_ = false; int device_id_; migraphx::target t_; OrtMutex mgx_mu_; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h new file mode 100644 index 0000000000..b0bae5581e --- /dev/null +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License + +#pragma once +#include "core/session/onnxruntime_cxx_api.h" +#include "core/framework/allocatormgr.h" +#include "core/framework/execution_provider.h" + +namespace onnxruntime { + +bool IsGraphInput(const GraphViewer& graph, const std::string& name) +{ + const auto& graph_inputs = graph.GetInputs(); + std::vector input_names(graph_inputs.size()); + std::transform(graph_inputs.begin(), graph_inputs.end(), input_names.begin(), [](auto in) { + return in->Name(); + }); + return (std::find(input_names.begin(), input_names.end(), name) != input_names.end()); +} + +bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, bool check_outer_scope = true) { + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + return graph.GetInitializedTensor(name, initializer); +} + +const Node* GetInputNode(const Node& node, int arg_index) { + int index = 0; + for (auto nit = node.InputNodesBegin(); nit != node.InputNodesEnd(); ++nit, ++index) + { + if (index == arg_index) + { + return &(*nit); + } + } + + return nullptr; +} + +std::size_t getNodeInputNum(const Node& node) +{ + std::size_t node_num = 0; + for(auto it = node.InputNodesBegin(); it != node.InputNodesEnd(); ++it) + { + node_num++; + } + + return node_num; +} + +bool isInputNode(const Node* node, const std::string& name) +{ + auto outputs = node->OutputDefs(); + return std::any_of(outputs.begin(), outputs.end(), [&](auto out) { + return (out->Name() == name); + }); +} + +bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector& input_nodes) +{ + if (node == nullptr) + { + return false; + } + + std::vector in_nodes; + for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit) + { + in_nodes.push_back(&(*nit)); + } + + if (node->OpType() == "Shape") + { + input_nodes.push_back(node->Index()); + return true; + } + + auto inputs = node->InputDefs(); + for (std::size_t i = 0; i < inputs.size(); ++i) + { + const std::string& input_name = inputs.at(i)->Name(); + // If it is an initializer, it can be constant folded + if (IsGraphInitializer(graph, input_name)) + { + continue; + } + + // Input for sure cannot be constant folded + if (IsGraphInput(graph, input_name)) + { + return false; + } + + // find the node corresponding to the name + auto nit = std::find_if(in_nodes.begin(), in_nodes.end(), [&](auto n) { + return isInputNode(n, input_name); + }); + if (nit == in_nodes.end()) + { + return false; + } + + auto input_node = (*nit); + // shape node, it is OK + if (input_node->OpType() == "Shape") + { + continue; + } + + if (canEvalShapeGeneral(graph, input_node, input_nodes)) + { + continue; + } + + return false; + } + + input_nodes.push_back(node->Index()); + return true; +} + +bool canEvalNodeArgument(const GraphViewer& graph, const Node* node, std::vector indices, std::vector& input_nodes) +{ + input_nodes.clear(); + std::vector in_nodes; + for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit) + { + in_nodes.push_back(&(*nit)); + } + + auto inputs = node->InputDefs(); + for (auto index : indices) + { + // an initializer itself is a constant + auto input_name = inputs.at(index)->Name(); + if (IsGraphInitializer(graph, input_name)) + { + continue; + } + + // Input cannot be constant folded + if (IsGraphInput(graph, input_name)) + { + return false; + } + + // find the node corresponding to the name + auto nit = std::find_if(in_nodes.begin(), in_nodes.end(), [&](auto n) { + return isInputNode(n, input_name); + }); + if (nit == in_nodes.end()) + { + return false; + } + + if (!canEvalShapeGeneral(graph, *nit, input_nodes)) + { + return false; + } + } + + return true; +} + +} diff --git a/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc b/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc new file mode 100644 index 0000000000..123c9b5d5d --- /dev/null +++ b/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" +#include "test/providers/provider_test_utils.h" +#include "test/framework/test_utils.h" +#include "gtest/gtest.h" +#include "test/util/include/default_providers.h" +#include "test/util/include/scoped_env_vars.h" +#include "core/providers/migraphx/migraphx_execution_provider_utils.h" +#include +#include + +using namespace std; +using namespace ONNX_NAMESPACE; +using namespace ::onnxruntime::logging; + +namespace onnxruntime { + +namespace test { + +template +void VerifyOutputs(const std::vector& fetches, const std::vector& expected_dims, + const std::vector& expected_values) { + ASSERT_EQ(1, fetches.size()); + auto& rtensor = fetches.front().Get(); + TensorShape expected_shape(expected_dims); + ASSERT_EQ(expected_shape, rtensor.Shape()); + const std::vector found(rtensor.template Data(), rtensor.template Data() + expected_values.size()); + ASSERT_EQ(expected_values, found); +} + +/** + * Create a simple model with two inputs and one initializer. + * input: "X", "Y" and "Z" + * output: "M" + * + * "X" "Y" + * \ / + * "Ini" Add + * | \ / + * | Add + * | \ + * \ Shape + * \ / + * Reshape + * | + * M + */ +void CreateBaseModel(onnxruntime::Model& model, std::vector dims) { + auto& graph = model.MainGraph(); + std::vector inputs; + std::vector outputs; + + // FLOAT tensor + ONNX_NAMESPACE::TypeProto float_tensor; + float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + + for (auto dim: dims) { + float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); + } + + // INT tensor + ONNX_NAMESPACE::TypeProto int64_tensor; + int64_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + int64_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3); + + // constant + TensorProto value_tensor; + value_tensor.add_dims(1); + value_tensor.add_float_data(1.f); + value_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + value_tensor.set_name("Ini"); + graph.AddInitializedTensor(value_tensor); + + // Create node1 (Add) + auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor); + auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor); + inputs.push_back(&input_arg_1); + inputs.push_back(&input_arg_2); + auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor); + outputs.push_back(&output_arg); + graph.AddNode("node_1", "Add", "node 1.", inputs, outputs); + + // Create node2 (Add) + auto& input_arg_3 = graph.GetOrCreateNodeArg("Ini", &float_tensor); + inputs.clear(); + inputs.push_back(&output_arg); + inputs.push_back(&input_arg_3); + auto& output_arg_2 = graph.GetOrCreateNodeArg("M", &float_tensor); + outputs.clear(); + outputs.push_back(&output_arg_2); + graph.AddNode("node_2", "Add", "node 2.", inputs, outputs); + + // Create node3 (Shape) + inputs.clear(); + outputs.clear(); + inputs.push_back(&output_arg_2); + auto& output_arg_3 = graph.GetOrCreateNodeArg("S", &int64_tensor); + outputs.push_back(&output_arg_3); + graph.AddNode("node_3", "Shape", "node 3.", inputs, outputs); + + // Create node4 (Reshape) + inputs.clear(); + outputs.clear(); + inputs.push_back(&input_arg_3); + inputs.push_back(&output_arg_3); + auto& output_arg_4 = graph.GetOrCreateNodeArg("R", &float_tensor); + outputs.push_back(&output_arg_4); + graph.AddNode("node_4", "Reshape", "node 4.", inputs, outputs); + + auto status = graph.Resolve(); + ASSERT_TRUE(status.IsOK()); +} + +TEST(MIGraphXExecutionProviderTest, GraphInputName) { + std::string graph_name = "migraphx_util_test"; + onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger()); + std::vector dims = {1, 3, 2}; + + CreateBaseModel(model, dims); + + auto& graph = model.MainGraph(); + GraphViewer gv(graph); + + ASSERT_EQ(IsGraphInput(gv, "X"), true); +} + +TEST(MIGraphXExecutionProviderTest, GraphInitializer) { + std::string graph_name = "migraphx_util_test"; + onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger()); + std::vector dims = {1, 3, 2}; + + CreateBaseModel(model, dims); + + auto& graph = model.MainGraph(); + GraphViewer gv(graph); + + ASSERT_EQ(IsGraphInitializer(gv, "Ini"), true); +} + +TEST(MIGraphXExecutionProviderTest, NodeInputNum) { + std::string graph_name = "migraphx_util_test"; + onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger()); + std::vector dims = {1, 3, 2}; + + CreateBaseModel(model, dims); + + auto& graph = model.MainGraph(); + GraphViewer gv(graph); + + // get the first add node + const auto& node0 = gv.GetNode(0); + const auto& node1 = gv.GetNode(1); + + ASSERT_EQ(getNodeInputNum(*node0), 0); + ASSERT_EQ(getNodeInputNum(*node1), 1); +} + +TEST(MIGraphXExecutionProviderTest, IsNodeInput) { + std::string graph_name = "migraphx_util_test"; + onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger()); + std::vector dims = {1, 3, 2}; + + CreateBaseModel(model, dims); + + auto& graph = model.MainGraph(); + GraphViewer gv(graph); + + // get the first add node + const auto& node2 = gv.GetNode(1); + ASSERT_EQ(isInputNode(node2, "M"), true); +} + +TEST(MIGraphXExecutionProviderTest, canEvalArgument) { + std::string graph_name = "migraphx_util_test"; + onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger()); + std::vector dims = {1, 3, 2}; + + CreateBaseModel(model, dims); + + auto& graph = model.MainGraph(); + GraphViewer gv(graph); + + // get the first add node + const auto& node2 = gv.GetNode(3); + std::vector input_nodes; + ASSERT_EQ(canEvalNodeArgument(gv, node2, {1}, input_nodes), true); +} + +} // namespace test +} // namespace onnxruntime