From e7d7fa8fa2c85f5b34bfb4c67f0dc181c5d43dc1 Mon Sep 17 00:00:00 2001 From: Shucai Xiao Date: Tue, 22 Jun 2021 15:39:51 -0500 Subject: [PATCH] Update migraphx to rocm4.2 (#7994) * update dockerfile for migraphx ep * update to rocm4.2 * code cleanup * fix error related to onnx unit tests --- dockerfiles/Dockerfile.migraphx | 11 +- .../migraphx/migraphx_execution_provider.cc | 126 ++++++++++++------ .../test/python/onnx_backend_test_series.py | 10 +- 3 files changed, 100 insertions(+), 47 deletions(-) diff --git a/dockerfiles/Dockerfile.migraphx b/dockerfiles/Dockerfile.migraphx index a395e69c50..b97f234b9e 100644 --- a/dockerfiles/Dockerfile.migraphx +++ b/dockerfiles/Dockerfile.migraphx @@ -21,10 +21,15 @@ ENV LANG C.UTF-8 # Install rocm RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && \ curl -sL http://repo.radeon.com/rocm/apt/debian/rocm.gpg.key | apt-key add - && \ - sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/3.7/ xenial main > /etc/apt/sources.list.d/rocm.list' + sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/4.2/ xenial main > /etc/apt/sources.list.d/rocm.list' RUN apt-get update &&\ - apt-get install -y sudo git bash build-essential cmake rocm-dkms libpython3.6-dev python3-pip miopen-hip rocblas half + apt-get install -y sudo git bash build-essential rocm-dev libpython3.6-dev python3-pip miopen-hip \ + rocblas half aria2 + +RUN aria2c -q -d /tmp -o cmake-3.20.3-Linux-x86_64.tar.gz \ +https://github.com/Kitware/CMake/releases/download/v3.20.3/cmake-3.20.3-Linux-x86_64.tar.gz &&\ +tar -zxf /tmp/cmake-3.20.3-Linux-x86_64.tar.gz --strip=1 -C /usr # Install rbuild RUN pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz @@ -34,7 +39,7 @@ ENV PATH /opt/miniconda/bin:/code/cmake-3.20.3-linux-x86_64/bin:${PATH} # Install MIGraphX from source RUN mkdir -p /migraphx RUN cd /migraphx && git clone --depth=1 --branch migraphx_for_ort https://github.com/ROCmSoftwarePlatform/AMDMIGraphX src -RUN cd /migraphx && rbuild package --cxx /opt/rocm-3.7.0/llvm/bin/clang++ -d /migraphx/deps -B /migraphx/build -S /migraphx/src/ -DPYTHON_EXECUTABLE=/usr/bin/python3 +RUN cd /migraphx && rbuild package --cxx /opt/rocm-4.2.0/llvm/bin/clang++ -d /migraphx/deps -B /migraphx/build -S /migraphx/src/ -DPYTHON_EXECUTABLE=/usr/bin/python3 RUN dpkg -i /migraphx/build/*.deb RUN rm -rf /migraphx diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 0f0e3a6d26..ed2d2811c4 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -12,7 +12,6 @@ #include "core/graph/graph_utils.h" #include "core/platform/env.h" #include "core/session/onnxruntime_cxx_api.h" -#include "core/optimizer/reshape_fusion.h" #include "migraphx_inc.h" #include "migraphx_execution_provider.h" #include "hip_allocator.h" @@ -204,7 +203,8 @@ static bool get_migraphx_type(ONNXTensorElementDataType type, return true; } -static bool can_eval_shape_general(const Graph& graph, const Node* node, const logging::Logger& logger) + +static bool can_eval_shape_general(const Graph& graph, const Node* node, const logging::Logger& logger, std::vector& input_nodes) { if (node == nullptr) { @@ -213,6 +213,7 @@ static bool can_eval_shape_general(const Graph& graph, const Node* node, const l if (node->OpType() == "Shape") { + input_nodes.push_back(node->Index()); return true; } @@ -245,7 +246,7 @@ static bool can_eval_shape_general(const Graph& graph, const Node* node, const l continue; } - if (can_eval_shape_general(graph, input_node, logger)) + if (can_eval_shape_general(graph, input_node, logger, input_nodes)) { continue; } @@ -253,11 +254,15 @@ static bool can_eval_shape_general(const Graph& graph, const Node* node, const l return false; } + input_nodes.push_back(node->Index()); + return true; } -static bool can_eval_node_argument(const Graph& graph, const Node* node, std::vector indices, const logging::Logger& logger) +static bool can_eval_node_argument(const Graph& graph, const Node* node, std::vector indices, const logging::Logger& logger, std::vector& input_nodes) { + input_nodes.clear(); + for (auto& arg_index : indices) { const std::string& input_name = graph_utils::GetNodeInputName(*node, arg_index); @@ -275,7 +280,7 @@ static bool can_eval_node_argument(const Graph& graph, const Node* node, std::ve } auto input_node = graph_utils::GetInputNode(*node, arg_index); - if (!can_eval_shape_general(graph, input_node, logger)) + if (!can_eval_shape_general(graph, input_node, logger, input_nodes)) { return false; } @@ -285,6 +290,7 @@ static bool can_eval_node_argument(const Graph& graph, const Node* node, std::ve } static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, const Node* node, const logging::Logger& logger) { + std::vector input_nodes; const auto& optype = node->OpType(); // const auto& initializers = graph_viewer.GetAllInitializedTensors(); if (optype == "ArgMax" or optype == "ArgMin") { @@ -295,7 +301,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } } else if (optype == "ConstantOfShape") { - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {0}, logger)) + if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {0}, logger, input_nodes)) { return true; } @@ -320,7 +326,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } } else if (optype == "Expand") { // MIGraphX only supports constant shape input values - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger)) + if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, input_nodes)) { return true; } @@ -390,12 +396,12 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } } else if (optype == "NonZero") { - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {0}, logger)) + if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {0}, logger, input_nodes)) { return true; } } else if (optype == "OneHot") { - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger)) + if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, input_nodes)) { return true; } @@ -403,7 +409,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co const auto& args = node->InputDefs(); // if pad size is not constant, migraphx cannot support if (args.size() >= 2) { - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger)) + if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, input_nodes)) { return true; } @@ -424,7 +430,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co // input value only applied to constant mode if (mode == "constant") { if (args.size() == 3) { - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {2}, logger)) + if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {2}, logger, input_nodes)) { return true; } @@ -434,14 +440,23 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co auto arg_num = node->InputDefs().size(); std::vector vec(arg_num); std::iota(vec.begin(), vec.end(), 0); - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, vec, logger)) + if (!can_eval_node_argument(graph_viewer.GetGraph(), node, vec, logger, input_nodes)) { return true; } } else if (optype == "Reshape") { const auto& args = node->InputDefs(); if (args.size() == 2) { - if (can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger)) + if (can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, input_nodes)) + { + return false; + } + return true; + } + } else if (optype == "ReduceSum") { + const auto& args = node->InputDefs(); + if (args.size() == 2) { + if (can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, input_nodes)) { return false; } @@ -455,7 +470,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co std::vector vec(arg_num); std::iota(vec.begin(), vec.end(), 0); vec.erase(vec.begin()); - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, vec, logger)) + if (!can_eval_node_argument(graph_viewer.GetGraph(), node, vec, logger, input_nodes)) { return true; } @@ -490,18 +505,36 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } } + + const auto& args = node->InputDefs(); + if (args.size() == 2) { + if (can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, input_nodes)) + { + return false; + } + return true; + } } else if (optype == "Tile") { - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger)) + if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, input_nodes)) { return true; } + } else if (optype == "Unsqueeze" or optype == "Squeeze") { + const auto& args = node->InputDefs(); + if (args.size() == 2) { + if (can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, input_nodes)) + { + return false; + } + return true; + } } //Op doesn't fall into known any of unsupported modes. return false; } -void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::vector>& clusters) +void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::vector>& clusters, const logging::Logger& logger) { // If the number of nodes in the graph is less than 5, do nothing // this is to deal with onnx unit tests @@ -516,6 +549,28 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v std::unordered_set op_names = {"AveragePool", "Conv", "Gemm", "LRN", "MatMul", "MaxPool"}; auto it = std::remove_if(clusters.begin(), clusters.end(), [&](auto git) { + for (auto index : git) + { + auto node = graph_viewer.GetNode(index); + if (node->OpType() == "Reshape") + { + const auto& args = node->InputDefs(); + if (args.size() == 2) { + std::vector node_inputs; + if (can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, node_inputs)) + { + return (not std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto index) { + return std::find(git.begin(), git.end(), index) != git.end(); + })); + } + else + { + return true; + } + } + } + } + // if 6 operators or more if (git.size() > 5) { @@ -640,21 +695,19 @@ static std::vector GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, /*out*/ std::unordered_set& mgx_required_initializers, const logging::Logger& logger) { - static std::set mgx_supported_ops = {"Abs", "Acos", "Acosh", "Add", "ArgMax", "ArgMin", - "Asin", "Asinh", "Atan", "Atanh", "AveragePool", "BatchNormalization", "Cast", "Ceil", "Clip", - "Concat", "Constant", "ConstantFill", "ConstantOfShape", "Conv", "Cos", "Cosh", - "Div", "Dropout", "Elu", "Equal", "Erf", "Exp", "Expand", "Flatten", "Floor", "GRU", "Gather", - "GatherElements", "Gemm", "GlobalAveragePool", "GlobalMaxPool", "Identity", "ImageScaler", - "InstanceNormalization", "LRN", "LSTM", "LeakyRelu", "Log", "LogSoftmax", "MatMul", "Max", - "MaxPool", "Min", "Mul", "Neg", "NonZero", "OneHot", "Pad", "Pow", "PRelu", - "RNN", "Range", "Reciprocal", "ReduceL1", "ReduceL2", "ReduceLogSum", "ReduceLogSumExp", "ReduceMax", - "ReduceMean", "ReduceMin", "ReduceProd", "ReduceSum", "ReduceSumSquare", "Relu", "Reshape", - "Round", "Shape", "Sigmoid", "Sign", "Sin", "Sinh", "Slice", "Softmax", "Split", "Sqrt", "Squeeze", - "Sub", "Sum", "Tan", "Tanh", "Tile", "Transpose", "Unsqueeze", "Where"}; + static std::set mgx_supported_ops = {"Abs", "Acos", "Acosh", "Add", "And", "ArgMax", "ArgMin", + "Asin", "Asinh", "Atan", "Atanh", "AveragePool", "BatchNormalization", "Cast", "Ceil", "Clip", + "Concat", "Constant", "ConstantFill", "ConstantOfShape", "Conv", "Cos", "Cosh", "DequantizeLinear", + "Div", "Dropout", "Elu", "Equal", "Erf", "Exp", "Expand", "Flatten", "Floor", "GRU", "Gather", + "GatherElements", "Gemm", "GlobalAveragePool", "GlobalMaxPool", "Greater", "Identity", "ImageScaler", + "InstanceNormalization", "LRN", "LSTM", "LeakyRelu", "Less", "LessOrEqual", "Log", "LogSoftmax", + "MatMul", "Max", "MaxPool", "Min", "Mul", "Neg", "NonZero", "OneHot", "Or", "Pad", "Pow", "PRelu", + "QuantizeLinear", "RNN", "Range", "Reciprocal", "ReduceL1", "ReduceL2", "ReduceLogSum", "ReduceLogSumExp", + "ReduceMax", "ReduceMean", "ReduceMin", "ReduceProd", "ReduceSum", "ReduceSumSquare", "Relu", "Reshape", + "Round", "Selu", "Shape", "Sigmoid", "Sign", "Sin", "Sinh", "Slice", "Softmax", "Split", "Sqrt", "Squeeze", + "Sub", "Sum", "Tan", "Tanh", "Tile", "Transpose", "Unsqueeze", "Where", "Xor"}; std::vector unsupported_nodes_idx; for (const auto& node_idx : graph_viewer.GetNodesInTopologicalOrder()) { - // auto node = graph_viewer.GetNode(node_idx); - // print_node_info(node); if (IsNodeSupported(mgx_supported_ops, graph_viewer, node_idx, logger)) { // Collect inputs that are initializers graph_viewer.GetNode(node_idx)->ForEachDef([&mgx_required_initializers, &graph_viewer](const onnxruntime::NodeArg& node_arg, bool is_input) { @@ -810,6 +863,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v std::unordered_map map_dim_param_values; onnxruntime::Graph& graph_build = model.MainGraph(); + for (const auto& node : graph_viewer.Nodes()) { std::vector inputs, outputs; for (auto input : node.InputDefs()) { @@ -835,6 +889,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); auto status = graph_build.Resolve(); + std::string onnx_string_buffer; model_proto.SerializeToString(&onnx_string_buffer); @@ -870,18 +925,10 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v AppendNodesToSubGraph(graph_viewer.GetNodesInTopologicalOrder(), inputs, outputs, result); } else { // unsupported_nodes_idx.empty() - // migraphx cannot handle Loop, If, and SoftmaxCrossEntropyLoss for now, - // so if a model contain any of these operators, fall back to CPU - std::unordered_set vec_ops = {"If", "Loop", "SoftmaxCrossEntropyLoss"}; - if (std::any_of(unsupported_nodes.begin(), unsupported_nodes.end(), [&](auto i) { - return (vec_ops.count(graph_viewer.GetNode(i)->OpType()) > 0); - })) { - return result; - } - auto mgx_clusters = GetPartitionedSubgraphs(graph_viewer.GetNodesInTopologicalOrder(), unsupported_nodes); + // check whether a subgrap should fallback to CPU - SubgraphPostProcessing(graph_viewer, mgx_clusters); + SubgraphPostProcessing(graph_viewer, mgx_clusters, *GetLogger()); for (const auto& this_cluster : mgx_clusters) { std::vector cluster_inputs, cluster_outputs; @@ -970,7 +1017,6 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::vector& node_compute_funcs) { migraphx::onnx_options options; bool no_input_shape = false; - for (const auto& fused_node : fused_nodes) { // map parameter input name to index std::unordered_map input_name_index; diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index 91b9cb828f..28914dcf6d 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -77,9 +77,7 @@ def create_backend_test(testname=None): if c2.supports_device('MIGRAPHX'): current_failing_tests += [ - '^test_constant_pad_cpu', '^test_softmax_axis_1_cpu', '^test_softmax_axis_0_cpu', - '^test_softmax_default_axis_cpu', '^test_round_cpu', '^test_lrn_default_cpu', '^test_lrn_cpu', - '^test_logsoftmax_axis_0_cpu', '^test_logsoftmax_axis_1_cpu', '^test_logsoftmax_default_axis_cpu', + '^test_constant_pad_cpu', '^test_round_cpu', '^test_lrn_default_cpu', '^test_lrn_cpu', '^test_dynamicquantizelinear_expanded_cpu', '^test_dynamicquantizelinear_max_adjusted_cpu', '^test_dynamicquantizelinear_max_adjusted_expanded_cpu', '^test_dynamicquantizelinear_min_adjusted_cpu', '^test_dynamicquantizelinear_min_adjusted_expanded_cpu', @@ -88,7 +86,11 @@ def create_backend_test(testname=None): '^test_operator_symbolic_override_nested_cpu', '^test_negative_log_likelihood_loss', '^test_softmax_cross_entropy', - '^test_greater_equal', '^test_less_equal' + '^test_greater_equal', + '^test_if_seq_cpu', + '^test_loop13_seq_cpu', + '^test_sequence_insert_at_back_cpu', + '^test_sequence_insert_at_front_cpu' ] # Skip these tests for a "pure" DML onnxruntime python wheel. We keep these tests enabled for instances where both DML and CUDA