From 9d7284fc3bbaf473d103d759c4fa1a1e88693189 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Tue, 4 Aug 2020 16:27:22 -0700 Subject: [PATCH] Enable MatMul + Scale fusion (#4669) Update TransposeMatMul to support scaling of the matrix product by a constant scalar value (analogous to the GEMM alpha parameter). Rename TransposeMatMul to TransposeScaleMatMul. Fuse MatMul with surrounding Mul/Div with constant scalar into TransposeScaleMatMul. --- .../core/optimizer/graph_transformer_level.h | 10 +- .../contrib_ops/cpu/cpu_contrib_kernels.cc | 4 +- ...se_matmul.cc => transpose_scale_matmul.cc} | 17 +- ...pose_matmul.h => transpose_scale_matmul.h} | 5 +- .../contrib_ops/cuda/cuda_contrib_kernels.cc | 12 +- ...se_matmul.cc => transpose_scale_matmul.cc} | 2 +- .../core/graph/contrib_ops/contrib_defs.cc | 13 +- .../core/optimizer/graph_transformer_utils.cc | 3 + .../core/optimizer/matmul_scale_fusion.cc | 282 ++++++++++++++++++ .../core/optimizer/matmul_scale_fusion.h | 39 +++ .../core/optimizer/matmul_transpose_fusion.cc | 8 +- .../core/providers/cuda/math/matmul.cc | 10 +- onnxruntime/core/providers/cuda/math/matmul.h | 12 +- ...t.cc => transpose_scale_matmul_op_test.cc} | 27 +- .../test/optimizer/graph_transform_test.cc | 119 +++++++- .../transform/fusion/matmul_scale_gen.py | 186 ++++++++++++ .../transform/fusion/matmul_scale_in0.onnx | Bin 0 -> 315 bytes .../fusion/matmul_scale_in0_in1.onnx | Bin 0 -> 387 bytes .../fusion/matmul_scale_in0_in1_out.onnx | Bin 0 -> 449 bytes .../matmul_scale_reused_input_scale.onnx | Bin 0 -> 449 bytes ...cale_transposescalematmul_in0_in1_out.onnx | Bin 0 -> 507 bytes .../matmul_scale_unfusable_div_not_scale.onnx | Bin 0 -> 323 bytes ...ul_scale_unfusable_scale_not_constant.onnx | Bin 0 -> 325 bytes ...tmul_scale_unfusable_scale_not_scalar.onnx | Bin 0 -> 334 bytes .../core/graph/gradient_builder.cc | 4 +- .../core/optimizer/graph_transformer_utils.cc | 83 +++--- .../core/optimizer/graph_transformer_utils.h | 17 +- .../core/session/training_session.cc | 78 ++--- .../optimizer/graph_transformer_utils_test.cc | 4 +- 29 files changed, 796 insertions(+), 139 deletions(-) rename onnxruntime/contrib_ops/cpu/{transpose_matmul.cc => transpose_scale_matmul.cc} (78%) rename onnxruntime/contrib_ops/cpu/{transpose_matmul.h => transpose_scale_matmul.h} (74%) rename onnxruntime/contrib_ops/cuda/math/{transpose_matmul.cc => transpose_scale_matmul.cc} (93%) create mode 100644 onnxruntime/core/optimizer/matmul_scale_fusion.cc create mode 100644 onnxruntime/core/optimizer/matmul_scale_fusion.h rename onnxruntime/test/contrib_ops/{transpose_matmul_op_test.cc => transpose_scale_matmul_op_test.cc} (83%) create mode 100644 onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py create mode 100644 onnxruntime/test/testdata/transform/fusion/matmul_scale_in0.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/matmul_scale_in0_in1.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/matmul_scale_in0_in1_out.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/matmul_scale_reused_input_scale.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/matmul_scale_transposescalematmul_in0_in1_out.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_div_not_scale.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_not_constant.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_not_scalar.onnx diff --git a/include/onnxruntime/core/optimizer/graph_transformer_level.h b/include/onnxruntime/core/optimizer/graph_transformer_level.h index 7aeb00ba66..cc05b219af 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_level.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_level.h @@ -7,11 +7,13 @@ namespace onnxruntime { +// Graph transformer level +// refer to docs/ONNX_Runtime_Graph_Optimizations.md for details enum class TransformerLevel : int { - Default = 0, - Level1, - Level2, - Level3, + Default = 0, // required transformers only + Level1, // basic optimizations + Level2, // extended optimizations + Level3, // layout optimizations // The max level should always be same as the last level. MaxLevel = Level3 }; diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 4bb12995f2..d75686bfcc 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -20,7 +20,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeScaleMatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MaxpoolWithMask); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Pad); @@ -130,7 +130,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/transpose_matmul.cc b/onnxruntime/contrib_ops/cpu/transpose_scale_matmul.cc similarity index 78% rename from onnxruntime/contrib_ops/cpu/transpose_matmul.cc rename to onnxruntime/contrib_ops/cpu/transpose_scale_matmul.cc index fcf5d3b1c4..0f3deeb02b 100644 --- a/onnxruntime/contrib_ops/cpu/transpose_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/transpose_scale_matmul.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "transpose_matmul.h" +#include "contrib_ops/cpu/transpose_scale_matmul.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/util/math.h" @@ -9,20 +9,21 @@ namespace onnxruntime { namespace contrib { ONNX_OPERATOR_KERNEL_EX( - TransposeMatMul, + TransposeScaleMatMul, kMSDomain, 1, kCpuExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - TransposeMatMul); + TransposeScaleMatMul); -TransposeMatMul::TransposeMatMul(const OpKernelInfo& info) +TransposeScaleMatMul::TransposeScaleMatMul(const OpKernelInfo& info) : OpKernel{info} { - ORT_ENFORCE(info.GetAttr("transA", &trans_a_attr_).IsOK()); - ORT_ENFORCE(info.GetAttr("transB", &trans_b_attr_).IsOK()); + ORT_THROW_IF_ERROR(info.GetAttr("alpha", &alpha_attr_)); + ORT_THROW_IF_ERROR(info.GetAttr("transA", &trans_a_attr_)); + ORT_THROW_IF_ERROR(info.GetAttr("transB", &trans_b_attr_)); } -Status TransposeMatMul::Compute(OpKernelContext* context) const { +Status TransposeScaleMatMul::Compute(OpKernelContext* context) const { concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); const Tensor* A = context->Input(0); @@ -47,7 +48,7 @@ Status TransposeMatMul::Compute(OpKernelContext* context) const { trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, helper.M(), helper.N(), helper.K(), - 1.0f, + alpha_attr_, A->Data() + helper.LeftOffsets()[i], B->Data() + helper.RightOffsets()[i], 0.0f, diff --git a/onnxruntime/contrib_ops/cpu/transpose_matmul.h b/onnxruntime/contrib_ops/cpu/transpose_scale_matmul.h similarity index 74% rename from onnxruntime/contrib_ops/cpu/transpose_matmul.h rename to onnxruntime/contrib_ops/cpu/transpose_scale_matmul.h index 98daae10eb..a953db6bb7 100644 --- a/onnxruntime/contrib_ops/cpu/transpose_matmul.h +++ b/onnxruntime/contrib_ops/cpu/transpose_scale_matmul.h @@ -8,13 +8,14 @@ namespace onnxruntime { namespace contrib { -class TransposeMatMul final : public OpKernel { +class TransposeScaleMatMul final : public OpKernel { public: - TransposeMatMul(const OpKernelInfo& info); + TransposeScaleMatMul(const OpKernelInfo& info); Status Compute(OpKernelContext* context) const override; private: + float alpha_attr_; int64_t trans_a_attr_, trans_b_attr_; }; diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 1005423035..cec944654f 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -17,9 +17,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeScaleMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeScaleMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeScaleMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Rfft); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Rfft); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Rfft); @@ -81,9 +81,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/math/transpose_matmul.cc b/onnxruntime/contrib_ops/cuda/math/transpose_scale_matmul.cc similarity index 93% rename from onnxruntime/contrib_ops/cuda/math/transpose_matmul.cc rename to onnxruntime/contrib_ops/cuda/math/transpose_scale_matmul.cc index eff9481f49..ba11171c8d 100644 --- a/onnxruntime/contrib_ops/cuda/math/transpose_matmul.cc +++ b/onnxruntime/contrib_ops/cuda/math/transpose_scale_matmul.cc @@ -9,7 +9,7 @@ namespace cuda { #define REGISTER_KERNEL_TYPED(T) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ - TransposeMatMul, \ + TransposeScaleMatMul, \ kMSDomain, \ 1, \ T, \ diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 4645bd71ad..7389f87d88 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1774,17 +1774,22 @@ Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy- ONNX_NAMESPACE::matmulShapeInference(ctx, 0, 1); }); - static const char* TransposeMatMul_doc = R"DOC( + static const char* TransposeScaleMatMul_doc = R"DOC( Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html )DOC"; - ONNX_CONTRIB_OPERATOR_SCHEMA(TransposeMatMul) + ONNX_CONTRIB_OPERATOR_SCHEMA(TransposeScaleMatMul) .SetDomain(kMSDomain) .SinceVersion(1) .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL) - .SetDoc("TransposeMatMul") + .SetDoc("TransposeScaleMatMul") .Input(0, "A", "N-dimensional matrix A", "T") .Input(1, "B", "N-dimensional matrix B", "T") + .Attr( + "alpha", + "Scalar multiplier for the product of the input tensors.", + AttributeProto::FLOAT, + 1.0f) .Attr( "transA", "Whether A should be transposed on the last two dimensions before doing multiplication", @@ -1800,7 +1805,7 @@ Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy- "T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors.") - .SetDoc(TransposeMatMul_doc) + .SetDoc(TransposeScaleMatMul_doc) .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); auto transAAttr = ctx.getAttribute("transA"); diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 3ea26f6438..e9ad6b6512 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -24,6 +24,7 @@ #include "core/optimizer/identity_elimination.h" #include "core/optimizer/layer_norm_fusion.h" #include "core/optimizer/matmul_add_fusion.h" +#include "core/optimizer/matmul_scale_fusion.h" #include "core/optimizer/nchwc_transformer.h" #include "core/optimizer/relu_clip_fusion.h" #include "core/optimizer/reshape_fusion.h" @@ -144,6 +145,8 @@ std::vector> GenerateTransformers(TransformerL transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); + + transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); #endif } break; diff --git a/onnxruntime/core/optimizer/matmul_scale_fusion.cc b/onnxruntime/core/optimizer/matmul_scale_fusion.cc new file mode 100644 index 0000000000..7356f55177 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_scale_fusion.cc @@ -0,0 +1,282 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/matmul_scale_fusion.h" + +#include "onnx/defs/attr_proto_util.h" + +#include "core/common/optional.h" +#include "core/framework/data_types_internal.h" +#include "core/framework/tensorprotoutils.h" +#include "core/graph/graph_utils.h" +#include "core/graph/graph_viewer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +namespace { +template +struct ExtractScalarAsFloatDispatchTarget { + Status operator()(const ONNX_NAMESPACE::TensorProto& tensor_proto, float& scalar_float) { + T scalar; + ORT_RETURN_IF_ERROR(utils::UnpackTensor(tensor_proto, &scalar, 1)); + scalar_float = static_cast(scalar); + return Status::OK(); + } +}; + +optional GetScalarConstantInitializer(const Graph& graph, const NodeArg& node_arg) { + const auto* initializer = graph_utils::GetConstantInitializer(graph, node_arg.Name()); + + if (!initializer) { + // not a constant + return {}; + } + + const auto* shape = node_arg.Shape(); + ORT_ENFORCE( + shape, + "Constant initializer NodeArg shape should not be null. NodeArg: ", node_arg.Name()); + + if (utils::GetTensorShapeFromTensorShapeProto(*shape).Size() != 1) { + // not a scalar + return {}; + } + + float scalar; + utils::MLTypeCallDispatcherRet< + Status, ExtractScalarAsFloatDispatchTarget, + uint32_t, uint64_t, int32_t, int64_t, MLFloat16, float, double> + dispatcher{initializer->data_type()}; + ORT_THROW_IF_ERROR(dispatcher.Invoke(*initializer, scalar)); + + return {scalar}; +} + +// gets the scale value and its input index if node is a fusable scale (Mul or Div by scalar constant) +optional> GetScaleFromNode( + const Graph& graph, const Node& scale_node, + const std::unordered_set& excluded_initializer_names) { + const auto is_excluded_initializer = + [&excluded_initializer_names](const NodeArg& node_arg) { + return excluded_initializer_names.find(node_arg.Name()) != excluded_initializer_names.end(); + }; + + if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Div", {7})) { + // (x / scale_reciprocal) + const auto div_inputs = scale_node.InputDefs(); + ORT_ENFORCE(div_inputs.size() == 2); + + const int scale_reciprocal_arg_index = 1; + const NodeArg& scale_reciprocal_node_arg = *div_inputs[scale_reciprocal_arg_index]; + + if (is_excluded_initializer(scale_reciprocal_node_arg)) return {}; + + const auto divisor = GetScalarConstantInitializer(graph, scale_reciprocal_node_arg); + + if (!divisor.has_value()) return {}; + + return {std::make_pair(1.0f / divisor.value(), scale_reciprocal_arg_index)}; + } + + if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Mul", {7})) { + // (x * scale) or (scale * x) + const auto mul_inputs = scale_node.InputDefs(); + ORT_ENFORCE(mul_inputs.size() == 2); + + for (int scale_arg_index = 0; scale_arg_index < 2; ++scale_arg_index) { + const NodeArg& scale_node_arg = *mul_inputs[scale_arg_index]; + + if (is_excluded_initializer(scale_node_arg)) continue; + + const auto multiplier = GetScalarConstantInitializer(graph, scale_node_arg); + + if (!multiplier.has_value()) continue; + + return {std::make_pair(multiplier.value(), scale_arg_index)}; + } + + return {}; + } + + return {}; +} + +struct ScaleMergeInfo { + // the edge from the base node to the original node + Node::EdgeConstIterator node_to_merge_edge; + // the scale of the original node + float scale; + // the index of the input or output def on the original node + // this def is moved to the fused node + // for a leading scale (scale -> MatMul), it will be the unscaled input + // for a trailing scale (MatMul -> scale), it will be the scaled output + int node_to_merge_def_index; + // the index of the input or output def on the fused node + int fused_node_def_index; +}; + +std::vector GetInputNodeMerges( + Graph& graph, Node& node, + const std::unordered_set& excluded_initializer_names) { + std::vector input_node_merges{}; + for (auto input_edge = node.InputEdgesBegin(); input_edge != node.InputEdgesEnd(); ++input_edge) { + const Node& input_node = input_edge->GetNode(); + + if (input_node.GetExecutionProviderType() != node.GetExecutionProviderType()) continue; + const auto scale_and_index = GetScaleFromNode(graph, input_node, excluded_initializer_names); + if (!scale_and_index.has_value()) continue; + + // assume scale nodes have 2 input defs, so to_scale_index == 1 - scale_index + ORT_ENFORCE(input_node.InputDefs().size() == 2 && scale_and_index.value().second < 2); + const int to_scale_index = 1 - scale_and_index.value().second; + + input_node_merges.push_back( + {input_edge, + scale_and_index.value().first, + to_scale_index, + input_edge->GetDstArgIndex()}); + } + return input_node_merges; +} + +std::vector GetOutputNodeMerges( + Graph& graph, Node& node, + const std::unordered_set& excluded_initializer_names) { + if (!optimizer_utils::CheckOutputEdges(graph, node, 1)) { + return {}; + } + + std::vector output_node_merges{}; + for (auto output_edge = node.OutputEdgesBegin(); output_edge != node.OutputEdgesEnd(); ++output_edge) { + const Node& output_node = output_edge->GetNode(); + + if (output_node.GetExecutionProviderType() != node.GetExecutionProviderType()) continue; + const auto scale_and_index = GetScaleFromNode(graph, output_node, excluded_initializer_names); + if (!scale_and_index.has_value()) continue; + + ORT_ENFORCE(output_node.OutputDefs().size() == 1); + const int scaled_index = 0; + + output_node_merges.push_back( + {output_edge, + scale_and_index.value().first, + scaled_index, + output_edge->GetSrcArgIndex()}); + } + return output_node_merges; +} + +Status ProcessNode( + Graph& graph, Node& node, bool& modified, + const std::unordered_set& excluded_initializer_names) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9}) && + !graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeScaleMatMul", {1}, kMSDomain)) { + return Status::OK(); + } + + const std::vector input_node_merges = GetInputNodeMerges( + graph, node, excluded_initializer_names); + const std::vector output_node_merges = GetOutputNodeMerges( + graph, node, excluded_initializer_names); + + if (input_node_merges.empty() && output_node_merges.empty()) { + return Status::OK(); + } + + NodeAttributes fused_node_attrs = + node.OpType() == "TransposeScaleMatMul" ? node.GetAttributes() : NodeAttributes{}; + + { + ONNX_NAMESPACE::AttributeProto& alpha_attr = fused_node_attrs["alpha"]; + float total_scale = utils::HasFloat(alpha_attr) ? alpha_attr.f() : 1.0f; + auto accumulate_scale = [&total_scale](const ScaleMergeInfo& fusion) { + total_scale *= fusion.scale; + }; + + std::for_each(input_node_merges.begin(), input_node_merges.end(), accumulate_scale); + std::for_each(output_node_merges.begin(), output_node_merges.end(), accumulate_scale); + + alpha_attr = ONNX_NAMESPACE::MakeAttribute("alpha", total_scale); + } + + auto get_mutable_node_to_merge = [&graph](const ScaleMergeInfo& merge) -> Node& { + return *graph.GetNode(merge.node_to_merge_edge->GetNode().Index()); + }; + + std::vector fused_node_inputs = node.MutableInputDefs(); + for (const auto& input_node_merge : input_node_merges) { + Node& input_node = get_mutable_node_to_merge(input_node_merge); + fused_node_inputs[input_node_merge.fused_node_def_index] = + input_node.MutableInputDefs()[input_node_merge.node_to_merge_def_index]; + } + + std::vector fused_node_outputs = node.MutableOutputDefs(); + for (const auto& output_node_merge : output_node_merges) { + Node& output_node = get_mutable_node_to_merge(output_node_merge); + fused_node_outputs[output_node_merge.fused_node_def_index] = + output_node.MutableOutputDefs()[output_node_merge.node_to_merge_def_index]; + } + + Node& matmul_scale_node = graph.AddNode( + graph.GenerateNodeName(node.Name() + "_FusedMatMulAndScale"), + "TransposeScaleMatMul", + "Fused MatMul and Scale", + fused_node_inputs, + fused_node_outputs, + &fused_node_attrs, + kMSDomain); + + matmul_scale_node.SetExecutionProviderType(node.GetExecutionProviderType()); + + { + std::vector> nodes_to_remove{node}; + + for (const auto& input_node_merge : input_node_merges) { + // remove merged input node's output edge + auto input_node_edge = input_node_merge.node_to_merge_edge; + Node& input_node = get_mutable_node_to_merge(input_node_merge); + graph.RemoveEdge( + input_node.Index(), node.Index(), + input_node_edge->GetSrcArgIndex(), input_node_edge->GetDstArgIndex()); + + // only remove merged input node if it has no more outputs + if (!optimizer_utils::CheckOutputEdges(graph, input_node, 0)) continue; + + nodes_to_remove.push_back(input_node); + } + + for (const auto& output_node_merge : output_node_merges) { + nodes_to_remove.push_back(get_mutable_node_to_merge(output_node_merge)); + } + + for (Node& node_to_remove : nodes_to_remove) { + graph_utils::RemoveNodeOutputEdges(graph, node_to_remove); + graph.RemoveNode(node_to_remove.Index()); + } + } + + modified = true; + + return Status::OK(); +} + +} // namespace + +Status MatMulScaleFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) + const { + GraphViewer graph_viewer{graph}; + const auto node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for (const auto node_index : node_indices) { + auto* node = graph.GetNode(node_index); + if (!node) continue; + + ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger)); + + ORT_RETURN_IF_ERROR(ProcessNode(graph, *node, modified, excluded_initializer_names_)); + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/matmul_scale_fusion.h b/onnxruntime/core/optimizer/matmul_scale_fusion.h new file mode 100644 index 0000000000..a2dd853f69 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_scale_fusion.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** + * Fuses MatMul with surrounding scales (multiplies or divides) by a constant + * scalar into TransposeScaleMatMul. + * + * For example, given matrices A and B and constant scalars t, u, and v: + * Mul(v, MatMul(Mul(t, A), Mul(u, B)) + * -> TransposeScaleMatMul(A, B, alpha=t*u*v) + */ +class MatMulScaleFusion : public GraphTransformer { + public: + /** + * Constructor. + * @param compatible_execution_providers The compatible execution providers. + * @param excluded_initializer_names Fusion will be skipped on scales by any + * of the named initializers. + */ + MatMulScaleFusion( + const std::unordered_set& compatible_execution_providers = {}, + const std::unordered_set& excluded_initializer_names = {}) + : GraphTransformer{"MatMulScaleFusion", compatible_execution_providers}, + excluded_initializer_names_{excluded_initializer_names} { + } + + private: + Status ApplyImpl( + Graph& graph, bool& modified, + int graph_level, const logging::Logger& logger) const override; + + const std::unordered_set excluded_initializer_names_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc index b78138ea6b..7f7d66bc5d 100644 --- a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc @@ -71,7 +71,7 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); if ((!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9}) && - !graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeMatMul", {1}, kMSDomain)) || + !graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeScaleMatMul", {1}, kMSDomain)) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { continue; } @@ -104,16 +104,16 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ const std::vector output_defs{node.MutableOutputDefs()[0]}; Node& matmul_node = graph.AddNode(graph.GenerateNodeName("MatMul_With_Transpose"), - "TransposeMatMul", + "TransposeScaleMatMul", "fused MatMul and Transpose ", input_defs, output_defs, {}, kMSDomain); bool transpose_left = (left != nullptr); - if (node.OpType() == "TransposeMatMul") { + if (node.OpType() == "TransposeScaleMatMul") { transpose_left ^= static_cast(node.GetAttributes().at("transA").i()); } bool transpose_right = (right != nullptr); - if (node.OpType() == "TransposeMatMul") { + if (node.OpType() == "TransposeScaleMatMul") { transpose_right ^= static_cast(node.GetAttributes().at("transB").i()); } matmul_node.AddAttribute("transA", static_cast(transpose_left)); diff --git a/onnxruntime/core/providers/cuda/math/matmul.cc b/onnxruntime/core/providers/cuda/math/matmul.cc index bec6cdb364..307dbb0e15 100644 --- a/onnxruntime/core/providers/cuda/math/matmul.cc +++ b/onnxruntime/core/providers/cuda/math/matmul.cc @@ -96,8 +96,8 @@ Status MatMul::ComputeInternal(OpKernelContext* ctx) const { if (Y->Shape().Size() == 0) return Status::OK(); - CudaT one = ToCudaType::FromFloat(1.0f); - CudaT zero = ToCudaType::FromFloat(0.0f); + const CudaT alpha = ToCudaType::FromFloat(alpha_); + const CudaT zero = ToCudaType::FromFloat(0.0f); cublasOperation_t transA = transa ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t transB = transb ? CUBLAS_OP_T : CUBLAS_OP_N; @@ -114,7 +114,7 @@ Status MatMul::ComputeInternal(OpKernelContext* ctx) const { static_cast(helper.N()), static_cast(helper.M()), static_cast(helper.K()), - &one, + &alpha, reinterpret_cast(right_X->template Data()), ldb, reinterpret_cast(left_X->template Data()), @@ -132,7 +132,7 @@ Status MatMul::ComputeInternal(OpKernelContext* ctx) const { static_cast(helper.N()), static_cast(helper.M()), static_cast(helper.K()), - &one, + &alpha, reinterpret_cast(right_X->template Data()), ldb, stride_B, @@ -167,7 +167,7 @@ Status MatMul::ComputeInternal(OpKernelContext* ctx) const { static_cast(helper.N()), static_cast(helper.M()), static_cast(helper.K()), - &one, + &alpha, right_arrays.GpuPtr(), ldb, left_arrays.GpuPtr(), diff --git a/onnxruntime/core/providers/cuda/math/matmul.h b/onnxruntime/core/providers/cuda/math/matmul.h index f19ee7d9cc..d9c12abbf7 100644 --- a/onnxruntime/core/providers/cuda/math/matmul.h +++ b/onnxruntime/core/providers/cuda/math/matmul.h @@ -13,16 +13,18 @@ class MatMul final : public CudaKernel { public: MatMul(const OpKernelInfo& info) - : CudaKernel(info) { - trans_A_ = info.GetAttrOrDefault("transA", 0); - trans_B_ = info.GetAttrOrDefault("transB", 0); + : CudaKernel(info), + alpha_{info.GetAttrOrDefault("alpha", 1.0f)}, + trans_A_{info.GetAttrOrDefault("transA", 0) != 0}, + trans_B_{info.GetAttrOrDefault("transB", 0) != 0} { } Status ComputeInternal(OpKernelContext* context) const override; private: - bool trans_A_; - bool trans_B_; + const float alpha_; + const bool trans_A_; + const bool trans_B_; }; } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/transpose_matmul_op_test.cc b/onnxruntime/test/contrib_ops/transpose_scale_matmul_op_test.cc similarity index 83% rename from onnxruntime/test/contrib_ops/transpose_matmul_op_test.cc rename to onnxruntime/test/contrib_ops/transpose_scale_matmul_op_test.cc index afe4f4f00d..d76192600c 100644 --- a/onnxruntime/test/contrib_ops/transpose_matmul_op_test.cc +++ b/onnxruntime/test/contrib_ops/transpose_scale_matmul_op_test.cc @@ -131,10 +131,10 @@ void ProcessInputs(const std::vector& input_dims, const std::vector& } template -void RunTransposeMatMulTest(int32_t opset_version = 7, bool transa = false, bool transb = false) { +void RunTransposeScaleMatMulTest(int32_t opset_version = 7, bool transa = false, bool transb = false, float alpha = 1.0f) { std::vector common_input_vals{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; for (auto t : GenerateSimpleTestCases()) { - OpTester test("TransposeMatMul", opset_version, onnxruntime::kMSDomain); + OpTester test("TransposeScaleMatMul", opset_version, onnxruntime::kMSDomain); std::vector input0_dims(t.input0_dims); std::vector input0_vals; @@ -150,6 +150,13 @@ void RunTransposeMatMulTest(int32_t opset_version = 7, bool transa = false, bool test.AddAttribute("transA", (int64_t)transa); test.AddAttribute("transB", (int64_t)transb); + test.AddAttribute("alpha", alpha); + + if (alpha != 1.0f) { + std::transform( + t.expected_vals.begin(), t.expected_vals.end(), t.expected_vals.begin(), + [alpha](const T& val) -> T { return alpha * val; }); + } test.AddOutput("Y", t.expected_dims, t.expected_vals); @@ -159,25 +166,31 @@ void RunTransposeMatMulTest(int32_t opset_version = 7, bool transa = false, bool } TEST(TransposeMatMulOpTest, FloatTypeNoTranspose) { - RunTransposeMatMulTest(1); + RunTransposeScaleMatMulTest(1); } #ifdef USE_CUDA // double support only implemented in CUDA kernel TEST(TransposeMatMulOpTest, DoubleTypeNoTranspose) { - RunTransposeMatMulTest(1); + RunTransposeScaleMatMulTest(1); } #endif TEST(TransposeMatMulOpTest, FloatTypeTransposeA) { - RunTransposeMatMulTest(1, true, false); + RunTransposeScaleMatMulTest(1, true, false); } TEST(TransposeMatMulOpTest, FloatTypeTransposeB) { - RunTransposeMatMulTest(1, false, true); + RunTransposeScaleMatMulTest(1, false, true); } TEST(TransposeMatMulOpTest, FloatTypeTransposeAB) { - RunTransposeMatMulTest(1, true, true); + RunTransposeScaleMatMulTest(1, true, true); +} + +TEST(TransposeMatMulOpTest, FloatTypeScale) { + RunTransposeScaleMatMulTest(1, false, false, 0.5f); + RunTransposeScaleMatMulTest(1, true, false, 2.0f); + RunTransposeScaleMatMulTest(1, true, true, 4.0f); } } // namespace transpose_matmul diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index b6ccbcd1b8..c2b8698d55 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -37,6 +37,7 @@ #include "core/optimizer/initializer.h" #include "core/optimizer/layer_norm_fusion.h" #include "core/optimizer/matmul_add_fusion.h" +#include "core/optimizer/matmul_scale_fusion.h" #include "core/optimizer/matmul_transpose_fusion.h" #include "core/optimizer/relu_clip_fusion.h" #include "core/optimizer/reshape_fusion.h" @@ -702,7 +703,7 @@ TEST_F(GraphTransformationTests, TransposeMatmulFusion) { std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Transpose"] == 0); ASSERT_TRUE(op_to_count["MatMul"] == 0); - ASSERT_TRUE(op_to_count["TransposeMatMul"] == 1); + ASSERT_TRUE(op_to_count["TransposeScaleMatMul"] == 1); } TEST_F(GraphTransformationTests, TransposeMatmulFusionOnTwoTranspose) { @@ -719,10 +720,10 @@ TEST_F(GraphTransformationTests, TransposeMatmulFusionOnTwoTranspose) { std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Transpose"] == 0); ASSERT_TRUE(op_to_count["MatMul"] == 0); - ASSERT_TRUE(op_to_count["TransposeMatMul"] == 1); + ASSERT_TRUE(op_to_count["TransposeScaleMatMul"] == 1); auto& node = *graph.Nodes().begin(); - ASSERT_TRUE(node.OpType() == "TransposeMatMul"); + ASSERT_TRUE(node.OpType() == "TransposeScaleMatMul"); ASSERT_TRUE(static_cast(node.GetAttributes().at("transA").i())); ASSERT_TRUE(static_cast(node.GetAttributes().at("transB").i())); } @@ -741,10 +742,10 @@ TEST_F(GraphTransformationTests, TransposeMatmulFusionOnThreeTranspose) { std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Transpose"] == 0); ASSERT_TRUE(op_to_count["MatMul"] == 0); - ASSERT_TRUE(op_to_count["TransposeMatMul"] == 1); + ASSERT_TRUE(op_to_count["TransposeScaleMatMul"] == 1); auto& node = *graph.Nodes().begin(); - ASSERT_TRUE(node.OpType() == "TransposeMatMul"); + ASSERT_TRUE(node.OpType() == "TransposeScaleMatMul"); ASSERT_FALSE(static_cast(node.GetAttributes().at("transA").i())); ASSERT_TRUE(static_cast(node.GetAttributes().at("transB").i())); } @@ -763,7 +764,7 @@ TEST_F(GraphTransformationTests, TransposeMatmulNoFusionOnInvalidPerm) { std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Transpose"] == 1); ASSERT_TRUE(op_to_count["MatMul"] == 1); - ASSERT_TRUE(op_to_count["TransposeMatMul"] == 0); + ASSERT_TRUE(op_to_count["TransposeScaleMatMul"] == 0); } TEST_F(GraphTransformationTests, Gemm_LeakyRelu_Fusion) { @@ -2824,5 +2825,111 @@ TEST_F(GraphTransformationTests, ComputationReductionTransformer_GatherND_E2E) { } #endif +#ifndef DISABLE_CONTRIB_OPS +template +static void TestMatMulScaleFusion( + const PathString& model_path, const Logger& logger, + GraphTransformationCheckFn graph_transformation_check, + const std::unordered_set& excluded_initializer_names = {}) { + SCOPED_TRACE(ORT_TSTR("model path: ") + model_path); + + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_path, model, nullptr, logger)); + Graph& graph = model->MainGraph(); + + auto original_op_counts = CountOpsInGraph(graph); + + onnxruntime::GraphTransformerManager graph_transformer_manager{5}; + ASSERT_STATUS_OK(graph_transformer_manager.Register( + make_unique(std::unordered_set{}, excluded_initializer_names), + TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformer_manager.ApplyTransformers(graph, TransformerLevel::Level2, logger)); + + auto transformed_op_counts = CountOpsInGraph(graph); + + graph_transformation_check(graph, original_op_counts, transformed_op_counts); +} + +TEST_F(GraphTransformationTests, MatMulScaleFusionFusableModels) { + const std::vector one_fusion_model_paths{ + MODEL_FOLDER "fusion/matmul_scale_in0.onnx", + MODEL_FOLDER "fusion/matmul_scale_in0_in1.onnx", + MODEL_FOLDER "fusion/matmul_scale_in0_in1_out.onnx", + MODEL_FOLDER "fusion/matmul_scale_transposescalematmul_in0_in1_out.onnx", + }; + + for (const auto& path : one_fusion_model_paths) { + TestMatMulScaleFusion( + path, *logger_, + [](const Graph& graph, + std::map original_op_counts, + std::map transformed_op_counts) { + EXPECT_EQ(transformed_op_counts["Mul"], 0); + EXPECT_EQ(transformed_op_counts["Div"], 0); + EXPECT_EQ(transformed_op_counts["MatMul"], 0); + EXPECT_EQ(transformed_op_counts["TransposeScaleMatMul"], 1); + + // check combined scale, individual scales should all have the same value + const float scale_value = 3.0f; + + const int num_scales = + original_op_counts["Mul"] + original_op_counts["Div"] + original_op_counts["TransposeScaleMatMul"]; + + auto fused_node = std::find_if( + graph.Nodes().cbegin(), graph.Nodes().cend(), + [](const Node& node) { return node.OpType() == "TransposeScaleMatMul"; }); + ASSERT_NE(fused_node, graph.Nodes().cend()); + + auto alpha_attr = fused_node->GetAttributes().find("alpha"); + ASSERT_NE(alpha_attr, fused_node->GetAttributes().end()); + + EXPECT_EQ(alpha_attr->second.f(), pow(scale_value, num_scales)); + }); + } +} + +TEST_F(GraphTransformationTests, MatMulScaleFusionUnfusableModels) { + const std::vector unfusable_model_paths{ + MODEL_FOLDER "fusion/matmul_scale_unfusable_div_not_scale.onnx", + MODEL_FOLDER "fusion/matmul_scale_unfusable_scale_not_scalar.onnx", + MODEL_FOLDER "fusion/matmul_scale_unfusable_scale_not_constant.onnx", + }; + + for (const auto& path : unfusable_model_paths) { + TestMatMulScaleFusion( + path, *logger_, + [](const Graph&, + const std::map& original_op_counts, + const std::map& transformed_op_counts) { + EXPECT_EQ(original_op_counts, transformed_op_counts); + }); + } +} + +TEST_F(GraphTransformationTests, MatMulScaleFusionReusedInputScale) { + TestMatMulScaleFusion( + MODEL_FOLDER "fusion/matmul_scale_reused_input_scale.onnx", *logger_, + [](const Graph&, + const std::map&, + std::map transformed_op_counts) { + EXPECT_EQ(transformed_op_counts["Mul"], 0); + EXPECT_EQ(transformed_op_counts["Div"], 0); + EXPECT_EQ(transformed_op_counts["MatMul"], 0); + EXPECT_EQ(transformed_op_counts["TransposeScaleMatMul"], 2); + }); +} + +TEST_F(GraphTransformationTests, MatMulScaleFusionExcludedInitializerName) { + TestMatMulScaleFusion( + MODEL_FOLDER "fusion/matmul_scale_in0.onnx", *logger_, + [](const Graph&, + const std::map& original_op_counts, + const std::map& transformed_op_counts) { + EXPECT_EQ(original_op_counts, transformed_op_counts); + }, + {"scale"}); +} +#endif + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py b/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py new file mode 100644 index 0000000000..7e63930f38 --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py @@ -0,0 +1,186 @@ +import onnx +from onnx import helper +from onnx import TensorProto +from onnx import OperatorSetIdProto + + +onnxdomain = OperatorSetIdProto() +onnxdomain.version = 12 +# The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" +msdomain = OperatorSetIdProto() +msdomain.version = 1 +msdomain.domain = "com.microsoft" +opsets = [onnxdomain, msdomain] + +scale_value = 3.0 + + +def save(model_path, nodes, inputs, outputs, initializers): + graph = helper.make_graph( + nodes, + "MatMulScaleTest", + inputs, outputs, initializers) + + model = helper.make_model( + graph, opset_imports=opsets, producer_name="onnxruntime-test") + + onnx.save(model, model_path) + + +def gen(model_path, + use_transpose_matmul, + scale_input_0, scale_input_1, scale_output): + matmul_op = "TransposeScaleMatMul" if use_transpose_matmul else "MatMul" + matmul_domain = "com.microsoft" if use_transpose_matmul else "" + matmul_attrs = {"alpha": scale_value} if use_transpose_matmul else {} + + nodes = [] + + if scale_input_0: + nodes.append(helper.make_node( + "Mul", ["input_0", "scale"], ["scaled_input_0"], "scale input_0")) + + if scale_input_1: + nodes.append(helper.make_node( + "Div", ["input_1", "scale_reciprocal"], ["scaled_input_1"], "scale input_1")) + + nodes.append(helper.make_node( + matmul_op, + [ + "scaled_input_0" if scale_input_0 else "input_0", + "scaled_input_1" if scale_input_1 else "input_1" + ], + [ + "unscaled_output" if scale_output else "output" + ], + matmul_op, + "", + matmul_domain, + **matmul_attrs)) + + if scale_output: + nodes.append(helper.make_node( + "Mul", ["scale", "unscaled_output"], ["output"], "scale output")) + + initializers = [ + helper.make_tensor("scale", TensorProto.FLOAT, [], [scale_value]), + helper.make_tensor("scale_reciprocal", + TensorProto.FLOAT, [], [1/scale_value]) + ] + + inputs = [ + helper.make_tensor_value_info( + "input_0", TensorProto.FLOAT, [2, 'M', 'K']), + helper.make_tensor_value_info( + "input_1", TensorProto.FLOAT, [2, 'K', 'N']) + ] + + outputs = [ + helper.make_tensor_value_info( + "output", TensorProto.FLOAT, [2, 'M', 'N']) + ] + + save(model_path, nodes, inputs, outputs, initializers) + + +gen("matmul_scale_in0.onnx", False, True, False, False) +gen("matmul_scale_in0_in1.onnx", False, True, True, False) +gen("matmul_scale_in0_in1_out.onnx", False, True, True, True) +gen("matmul_scale_transposescalematmul_in0_in1_out.onnx", True, True, True, True) + + +UNFUSABLE_DIV_NOT_SCALE = 0 +UNFUSABLE_SCALE_NOT_SCALAR = 1 +UNFUSABLE_SCALE_NOT_CONSTANT = 2 + + +def gen_unfusable(model_path, unfusable_type): + matmul_op = "MatMul" + + if unfusable_type == UNFUSABLE_DIV_NOT_SCALE: + scale_node = helper.make_node( + "Div", ["scale", "input_0"], ["scaled_input_0"], "scale input_0") + elif unfusable_type == UNFUSABLE_SCALE_NOT_SCALAR: + scale_node = helper.make_node( + "Mul", ["scale_non_scalar", "input_0"], ["scaled_input_0"], "scale input_0") + elif unfusable_type == UNFUSABLE_SCALE_NOT_CONSTANT: + scale_node = helper.make_node( + "Mul", ["input_0", "input_0"], ["scaled_input_0"], "scale input_0") + else: + raise ValueError("Invalid unfusable_type: {}".format(unfusable_type)) + + nodes = [ + scale_node, + helper.make_node( + matmul_op, ["scaled_input_0", "input_1"], ["output"], matmul_op) + ] + + initializers = [ + helper.make_tensor("scale", TensorProto.FLOAT, [], [scale_value]), + helper.make_tensor("scale_non_scalar", TensorProto.FLOAT, + [2, 1, 1], [scale_value, scale_value]) + ] + + inputs = [ + helper.make_tensor_value_info( + "input_0", TensorProto.FLOAT, [2, 'M', 'K']), + helper.make_tensor_value_info( + "input_1", TensorProto.FLOAT, [2, 'K', 'N']) + ] + + outputs = [ + helper.make_tensor_value_info( + "output", TensorProto.FLOAT, [2, 'M', 'N']) + ] + + save(model_path, nodes, inputs, outputs, initializers) + + +gen_unfusable("matmul_scale_unfusable_div_not_scale.onnx", + UNFUSABLE_DIV_NOT_SCALE) +gen_unfusable("matmul_scale_unfusable_scale_not_scalar.onnx", + UNFUSABLE_SCALE_NOT_SCALAR) +gen_unfusable("matmul_scale_unfusable_scale_not_constant.onnx", + UNFUSABLE_SCALE_NOT_CONSTANT) + + +def gen_reused_input_scale(model_path): + matmul_op = "MatMul" + + nodes = [ + helper.make_node( + "Mul", ["input_0", "scale"], ["scaled_input_0"], + "scale input_0"), + helper.make_node( + matmul_op, ["scaled_input_0", "input_1"], ["output_0"], + "MatMul input_0 and input_1"), + helper.make_node( + matmul_op, ["scaled_input_0", "input_2"], ["output_1"], + "MatMul input_0 and input_2") + ] + + initializers = [ + helper.make_tensor("scale", TensorProto.FLOAT, [], [scale_value]) + ] + + inputs = [ + helper.make_tensor_value_info( + "input_0", TensorProto.FLOAT, [2, 'M', 'K']), + helper.make_tensor_value_info( + "input_1", TensorProto.FLOAT, [2, 'K', 'N']), + helper.make_tensor_value_info( + "input_2", TensorProto.FLOAT, [2, 'K', 'N']) + ] + + outputs = [ + helper.make_tensor_value_info( + "output_0", TensorProto.FLOAT, [2, 'M', 'N']), + helper.make_tensor_value_info( + "output_1", TensorProto.FLOAT, [2, 'M', 'N']) + ] + + save(model_path, nodes, inputs, outputs, initializers) + + +gen_reused_input_scale("matmul_scale_reused_input_scale.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_in0.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_in0.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e067d260d791ed0117be53cfafe8102ba3f94565 GIT binary patch literal 315 zcmd;J6B5YJ%d03V%`3^wP1P+)EiSR@X5up8V$aMgC@qOM;9@OKPRvOa;saAD@latY zUNBz)%2i_aEzRLF=0Z~iGsRGdEx)t`Xr2_CZ(<2hsS=d7Vi4kohy;Ud2?1KK#V^39 z#KOSf;NS$YT1yHfuzJ-hJ0}6KrSV0n$(aR3`9K{}@-Qn5ghaT6I2eWaxR^MYxR`|) zeL;jbLV+Q=0&kE4za%*>sN2xx!3Ov_v2Za6@Hh!_@h0cz>g8r87v&e{r zl7~6nKuCm3h=WmxkBf@Jo{8g1QY|9&CW06AKrE0FRR( V7jJTYu3m0ta#4P9ep-nDBLL;wX)^!- literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_in0_in1_out.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_in0_in1_out.onnx new file mode 100644 index 0000000000000000000000000000000000000000..2a8ce8ebf4cc28108b702b0e7ad37220aa83bd9b GIT binary patch literal 449 zcmd;J6B5YJ%d03V%`3^wP1P+)EiSQ|$joKJ#h#g0P+Agiz{OgeoS2g;#0REQ;-SJ) zykNcpl&i$-TbjdV57TAHB>Rhru>@$h5|p-L;4+3fnG0R55F1nv55(;d9@yzZ{19V;L4FMZ27(s9 z0HYEM1A~Ku6C?n%q(B0zSFN&h!th&^Jj~+;LLyv39E?JITudBHT+Bj@z97OIp}-Jb sfj3BjUy>Xb)N$zYU<3S|ShyGjc$@^ec$4#U^>Q?t63_Z6zKv+O`eqno~i34)nk9E@TG=< zoyZdI55FSt)N{&khFzgiKzPixw24DOAUHqq&r;9^n3p@wxAUuPCGXNW-q&wEX3z507?iM!B=b+7* S?}}g(55e*#@hKc<_0SLKo|8-f literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_div_not_scale.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_div_not_scale.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d5894c311cd6620fe2fda34727772202e46d50f3 GIT binary patch literal 323 zcmZ9G!3u&v5QbeXTZXc2>Exk8It4{Sr%oc>!Y*~{B4!3c?uM?>d-TG&s}gvbnQ!Kw z`3IYnaK=7)$wZnZYmpRU|J7j!{UVN@2@F#9Dn+=Zmg0NZusu?|XwWeZ=^I?=sM>Xx zn!FU%A=?Zh5tL7-$#h)vS(;rc^VOIjr>AM2=hq*smRQFKqpFQ%TaAb^4;g2ntfKrn zZ>_d;3KS!1fsS=BC<p+A99%(G3F#yKzeLlE=> DEyGUx literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_not_constant.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_not_constant.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8cd50705ec506afbb46c1f7d19aef5e15d99b5ef GIT binary patch literal 325 zcmZ9GF%N<;6oe@#HZKababi$IoJ=&ZI5-hjgNsfs1Q8PwX$_RbKjQzDR?vi{efRF` zU9m|CXY5thOyotj5m_bnZyk0p%Gt9PvG*IORi()zyTz^HOq5=>(lw&`0Y3Do?zvA* zUW+EtF{4B@sN3Gu{_2wqKR5Mjn;>_fX+aPURl|CPb&N1-)>y8UMU)jY&SE(w-T zd2|L8BWi(;bucK3JTG%^mSFiCl_E%LPF6~7CDJ|#L& A!vFvP literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_not_scalar.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_not_scalar.onnx new file mode 100644 index 0000000000000000000000000000000000000000..bd181311fb1fa958cc7df15770431e845933d363 GIT binary patch literal 334 zcmZXO!3%;g7{#5IjgQhbI(ev&PC=p9E{WH$OP#uinSqd7=!E{7{?^=50=sN`-}~)* z51W*5#$IL3L|$YYkyT>%)?p8%N|Q$xGtOf9N=g{z>{*N0qn5I6ajTt)(#ux5MpQq* zhaS~^_o>Ni(JXSzC=m_nwl}rE`XqDDO^>xrkUP+{AP9#lVZFjSMi@0?EXT_F>v>sp zQ68NE#fVyFj_ literal 0 HcmV?d00001 diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 48b9635f43..b558e49246 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -254,7 +254,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) { if (IsGradientRequiredForSrcNodeInput(0)) { ArgDef pre_reduce_grad_0 = IA("PreReduceGrad0"); result.push_back( - NodeDef(OpDef{"TransposeMatMul", kMSDomain, 1}, + NodeDef(OpDef{"TransposeScaleMatMul", kMSDomain, 1}, {GO(0), B}, {pre_reduce_grad_0}, {{"transB", MakeAttribute("transB", int64_t(1))}})); @@ -267,7 +267,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) { if (IsGradientRequiredForSrcNodeInput(1)) { ArgDef pre_reduce_grad_1 = IA("PreReduceGrad1"); result.push_back( - NodeDef(OpDef{"TransposeMatMul", kMSDomain, 1}, + NodeDef(OpDef{"TransposeScaleMatMul", kMSDomain, 1}, {A, GO(0)}, {pre_reduce_grad_1}, {{"transA", MakeAttribute("transA", int64_t(1))}})); diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 4b31d9cfc3..9424a78c26 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -3,51 +3,53 @@ #include "orttraining/core/optimizer/graph_transformer_utils.h" +#include "core/mlas/inc/mlas.h" +#include "core/optimizer/bias_gelu_fusion.h" +#include "core/optimizer/cast_elimination.h" +#include "core/optimizer/computation_reduction.h" +#include "core/optimizer/constant_folding.h" +#include "core/optimizer/conv_activation_fusion.h" +#include "core/optimizer/conv_add_fusion.h" +#include "core/optimizer/conv_bn_fusion.h" +#include "core/optimizer/conv_mul_fusion.h" +#include "core/optimizer/dropout_elimination.h" +#include "core/optimizer/embed_layer_norm_fusion.h" +#include "core/optimizer/expand_elimination.h" +#include "core/optimizer/fast_gelu_fusion.h" +#include "core/optimizer/free_dim_override_transformer.h" +#include "core/optimizer/gelu_approximation.h" +#include "core/optimizer/gelu_fusion.h" +#include "core/optimizer/gemm_activation_fusion.h" +#include "core/optimizer/graph_transformer_utils.h" +#include "core/optimizer/identity_elimination.h" +#include "core/optimizer/layer_norm_fusion.h" +#include "core/optimizer/matmul_add_fusion.h" +#include "core/optimizer/matmul_scale_fusion.h" +#include "core/optimizer/matmul_transpose_fusion.h" +#include "core/optimizer/nchwc_transformer.h" +#include "core/optimizer/relu_clip_fusion.h" +#include "core/optimizer/reshape_fusion.h" +#include "core/optimizer/rule_based_graph_transformer.h" +#include "core/optimizer/shape_to_initializer.h" +#include "core/optimizer/skip_layer_norm_fusion.h" +#include "core/optimizer/slice_elimination.h" +#include "core/optimizer/unsqueeze_elimination.h" +#include "core/session/inference_session.h" #include "orttraining/core/framework/distributed_run_context.h" +#include "orttraining/core/optimizer/bias_dropout_fusion.h" #include "orttraining/core/optimizer/insert_output_rewriter.h" #include "orttraining/core/optimizer/megatron_transformer.h" -#include "orttraining/core/optimizer/bias_dropout_fusion.h" #include "orttraining/core/optimizer/nonzero_shape_setter.h" -#include "core/optimizer/identity_elimination.h" -#include "core/optimizer/slice_elimination.h" -#include "core/optimizer/conv_mul_fusion.h" -#include "core/optimizer/conv_bn_fusion.h" -#include "core/optimizer/conv_add_fusion.h" -#include "core/optimizer/constant_folding.h" -#include "core/optimizer/unsqueeze_elimination.h" -#include "core/optimizer/expand_elimination.h" -#include "core/optimizer/cast_elimination.h" -#include "core/optimizer/rule_based_graph_transformer.h" -#include "core/optimizer/conv_activation_fusion.h" -#include "core/optimizer/gemm_activation_fusion.h" -#include "core/optimizer/matmul_add_fusion.h" -#include "core/optimizer/dropout_elimination.h" -#include "core/optimizer/relu_clip_fusion.h" -#include "core/optimizer/shape_to_initializer.h" -#include "core/optimizer/nchwc_transformer.h" -#include "core/optimizer/free_dim_override_transformer.h" -#include "core/optimizer/gelu_fusion.h" -#include "core/optimizer/layer_norm_fusion.h" -#include "core/optimizer/skip_layer_norm_fusion.h" -#include "core/optimizer/embed_layer_norm_fusion.h" -#include "core/optimizer/reshape_fusion.h" -#include "core/optimizer/matmul_transpose_fusion.h" -#include "core/optimizer/bias_gelu_fusion.h" -#include "core/optimizer/fast_gelu_fusion.h" -#include "core/optimizer/gelu_approximation.h" -#include "core/optimizer/graph_transformer_utils.h" -#include "core/optimizer/computation_reduction.h" -#include "core/mlas/inc/mlas.h" -#include "core/session/inference_session.h" namespace onnxruntime { namespace training { namespace transformer_utils { -std::vector> GeneratePreTrainingTransformers(TransformerLevel level, - const std::unordered_set& weights_to_train, - bool enable_gelu_approximation, - const std::vector& transformers_and_rules_to_enable) { +std::vector> GeneratePreTrainingTransformers( + TransformerLevel level, + const std::unordered_set& weights_to_train, + bool enable_gelu_approximation, + const std::vector& transformers_and_rules_to_enable) { std::vector> transformers; std::unique_ptr rule_transformer = nullptr; @@ -130,9 +132,11 @@ std::vector> GeneratePreTrainingTransformers(T return filtered_list; } -std::vector> GenerateTransformers(TransformerLevel level, - gsl::span free_dimension_overrides, - const std::vector& transformers_and_rules_to_enable) { +std::vector> GenerateTransformers( + TransformerLevel level, + const std::unordered_set& weights_to_train, + gsl::span free_dimension_overrides, + const std::vector& transformers_and_rules_to_enable) { std::vector> transformers; std::unique_ptr rule_transformer = nullptr; switch (level) { @@ -146,6 +150,7 @@ std::vector> GenerateTransformers(TransformerL transformers.emplace_back(onnxruntime::make_unique(free_dimension_overrides)); transformers.emplace_back(onnxruntime::make_unique(l1_execution_providers)); transformers.emplace_back(onnxruntime::make_unique(l1_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(l1_execution_providers, weights_to_train)); rule_transformer = optimizer_utils::GenerateRuleBasedGraphTransformer(level, transformers_and_rules_to_enable, l1_execution_providers); } break; diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.h b/orttraining/orttraining/core/optimizer/graph_transformer_utils.h index 45dceca7ab..87a107e5df 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.h +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.h @@ -14,17 +14,20 @@ namespace training { namespace transformer_utils { /** Generates all pre-training transformers for this level. */ -std::vector> GeneratePreTrainingTransformers(TransformerLevel level, - const std::unordered_set& weights_to_train, - bool enable_gelu_approximation, - const std::vector& rules_and_transformers_to_enable = {}); +std::vector> GeneratePreTrainingTransformers( + TransformerLevel level, + const std::unordered_set& weights_to_train, + bool enable_gelu_approximation, + const std::vector& rules_and_transformers_to_enable = {}); /** Generates all predefined (both rule-based and non-rule-based) transformers for this level. If transformers_and_rules_to_enable is not empty, it returns the intersection between the predefined transformers/rules and the transformers_and_rules_to_enable. */ -std::vector> GenerateTransformers(TransformerLevel level, - gsl::span free_dimension_overrides, - const std::vector& rules_and_transformers_to_enable = {}); +std::vector> GenerateTransformers( + TransformerLevel level, + const std::unordered_set& weights_to_train, + gsl::span free_dimension_overrides, + const std::vector& rules_and_transformers_to_enable = {}); } // namespace transformer_utils } // namespace training diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index be9f47ab4e..9d708d263b 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -109,16 +109,16 @@ bool IsRootNode(const TrainingSession::TrainingConfiguration& config) { } } // namespace - void TrainingSession::FilterUnusedWeights(const std::unordered_set& weight_names_to_train, - std::unordered_set& filtered_weight_names_to_train) { + std::unordered_set& filtered_weight_names_to_train) { filtered_weight_names_to_train.clear(); - for (const auto& name: weight_names_to_train) { + for (const auto& name : weight_names_to_train) { auto nodes = model_->MainGraph().GetConsumerNodes(name); if (!nodes.empty()) filtered_weight_names_to_train.insert(name); else - LOGS(*session_logger_, WARNING) << "Couldn't find any consumer node for weight " << name << ", exclude it from training."; + LOGS(*session_logger_, WARNING) + << "Couldn't find any consumer node for weight " << name << ", exclude it from training."; } } @@ -138,7 +138,8 @@ Status TrainingSession::ConfigureForTraining( TrainingConfigurationResult config_result{}; ORT_ENFORCE(config.distributed_config.pipeline_parallel_size > 0, - "This parameter should be 1 if there is no pipelie parallelism. Otherwise, it's the number of pipeline stages."); + "This parameter should be 1 if there is no pipeline parallelism. " + "Otherwise, it's the number of pipeline stages."); DistributedRunContext::CreateInstance({config.distributed_config.world_rank, config.distributed_config.world_size, @@ -272,8 +273,9 @@ Status TrainingSession::ConfigureForTraining( } pipeline_result.fetch_names.push_back(name); } - pipeline_result.pipeline_stage_id = config.distributed_config.world_rank / - (config.distributed_config.data_parallel_size * config.distributed_config.horizontal_parallel_size); + pipeline_result.pipeline_stage_id = + config.distributed_config.world_rank / + (config.distributed_config.data_parallel_size * config.distributed_config.horizontal_parallel_size); config_result.pipeline_config_result = pipeline_result; } @@ -478,21 +480,24 @@ static Status AddGradientAccumulationNodes(Graph& graph, gradient_accumulation_buffers.resize(gradient_argdefs.size()); std::vector grad_acc_outputs; for (size_t i = 0; i < gradient_argdefs.size(); ++i) { - grad_acc_outputs.push_back(BuildGradientAccumulationNode( - nodearg_name_generator, gradient_argdefs[i], gradient_accumulation_buffers[i], graph_defs, false) - .name); + grad_acc_outputs.push_back( + BuildGradientAccumulationNode( + nodearg_name_generator, gradient_argdefs[i], gradient_accumulation_buffers[i], graph_defs, false) + .name); } return GraphAugmenter::AugmentGraph(graph, graph_defs); } -Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set& weights_to_train, bool enable_gelu_approximation) { +Status TrainingSession::ApplyTransformationsToMainGraph( + const std::unordered_set& weights_to_train, bool enable_gelu_approximation) { GraphTransformerManager graph_transformation_mgr{1}; AddPreTrainingTransformers(graph_transformation_mgr, weights_to_train, enable_gelu_approximation); // apply transformers Graph& graph = model_->MainGraph(); for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { - ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, static_cast(i), *session_logger_)); + ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers( + graph, static_cast(i), *session_logger_)); } return common::Status::OK(); } @@ -530,7 +535,8 @@ void TrainingSession::AddPredefinedTransformers(GraphTransformerManager& transfo const std::vector& custom_list) { auto add_transformers = [&](TransformerLevel level) { // Generate and register transformers for level - auto transformers_to_register = transformer_utils::GenerateTransformers(level, GetSessionOptions().free_dimension_overrides, custom_list); + auto transformers_to_register = transformer_utils::GenerateTransformers( + level, weights_to_train_, GetSessionOptions().free_dimension_overrides, custom_list); for (auto& entry : transformers_to_register) { transformer_manager.Register(std::move(entry), level); } @@ -636,10 +642,12 @@ Status TrainingSession::ConfigureLossFunction( return DoPostLoadProcessing(*model_); } -Status TrainingSession::EnableMixedPrecision(const std::unordered_set& weights_to_train, - bool use_fp16_initializer, - std::unordered_map& fp32_weight_name_to_fp16_node_arg) { - ORT_RETURN_IF_ERROR(TransformGraphForMixedPrecision(model_->MainGraph(), weights_to_train, use_fp16_initializer, fp32_weight_name_to_fp16_node_arg)); +Status TrainingSession::EnableMixedPrecision( + const std::unordered_set& weights_to_train, + bool use_fp16_initializer, + std::unordered_map& fp32_weight_name_to_fp16_node_arg) { + ORT_RETURN_IF_ERROR(TransformGraphForMixedPrecision( + model_->MainGraph(), weights_to_train, use_fp16_initializer, fp32_weight_name_to_fp16_node_arg)); std::unordered_set fp16_weight_initializer_names{}; std::transform( @@ -797,7 +805,8 @@ Status TrainingSession::Save(const PathString& model_uri, TrainingSession::SaveO auto status = Model::Save(*new_model, model_uri); if (!status.IsOK()) { - LOGS(*session_logger_, WARNING) << "Error when saving model " << ToMBString(model_uri) << " : " << status.ErrorMessage(); + LOGS(*session_logger_, WARNING) + << "Error when saving model " << ToMBString(model_uri) << " : " << status.ErrorMessage(); } return status; @@ -817,8 +826,8 @@ bool TrainingSession::IsGraphOutputFp32Node(const std::string& output_name) cons ORT_ENFORCE(output_producer_node != nullptr, "Output: " + output_name + " is not produced by any node."); for (auto output : output_producer_node->OutputDefs()) { - if (output->Name() == output_name && output->TypeAsProto() != nullptr && output->TypeAsProto()->has_tensor_type() - && output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + if (output->Name() == output_name && output->TypeAsProto() != nullptr && output->TypeAsProto()->has_tensor_type() && + output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { return true; } } @@ -835,24 +844,22 @@ common::Status TrainingSession::Run(const RunOptions& run_options, IOBinding& io for (auto& drop_ratio : dropout_eval_feeds_) { OrtValue feed_value; // We allocate on CPU first, copy will be taken care of downstream. - auto cpu_allocator = GetSessionState().GetExecutionProviders() - .Get(onnxruntime::kCpuExecutionProvider) - ->GetAllocator(0, OrtMemTypeDefault); + const auto* cpu_ep = GetSessionState().GetExecutionProviders().Get(onnxruntime::kCpuExecutionProvider); + const auto cpu_allocator = cpu_ep->GetAllocator(0, OrtMemTypeDefault); feed_value = onnxruntime::MakeScalarMLValue(cpu_allocator, 0.f, true /*is_1d*/); // Bind new feed to graph input. new_feeds.emplace_back(drop_ratio, feed_value); } - } - else { + } else { auto& input_names = io_binding.GetInputNames(); - if (GetSessionState().GetInputNodeInfoMap().find(training_mode_string_) != GetSessionState().GetInputNodeInfoMap().end() && + if (GetSessionState().GetInputNodeInfoMap().find(training_mode_string_) != + GetSessionState().GetInputNodeInfoMap().end() && std::find(input_names.begin(), input_names.end(), training_mode_string_) == input_names.end()) { // Set training_mode input to false OrtValue training_mode_feed_value; // We allocate on CPU first, copy will be taken care of downstream. - auto cpu_allocator = GetSessionState().GetExecutionProviders() - .Get(onnxruntime::kCpuExecutionProvider) - ->GetAllocator(0, OrtMemTypeDefault); + const auto* cpu_ep = GetSessionState().GetExecutionProviders().Get(onnxruntime::kCpuExecutionProvider); + const auto cpu_allocator = cpu_ep->GetAllocator(0, OrtMemTypeDefault); training_mode_feed_value = onnxruntime::MakeScalarMLValue(cpu_allocator, false, true /*is_1d*/); new_feeds.emplace_back(training_mode_string_, training_mode_feed_value); } @@ -879,18 +886,18 @@ Status TrainingSession::SetEvalFeedNames() { for (auto& node : graph.Nodes()) { auto it = Nodes_Need_Eval_Feeds.find(node.OpType()); - if(it != Nodes_Need_Eval_Feeds.cend()) { + if (it != Nodes_Need_Eval_Feeds.cend()) { // The opset is < 12, add each ratio input to graph inputs for overriding. // Needs to be removed when TrainableDropout is deprecated. - if(it->compare("TrainableDropout") == 0) { + if (it->compare("TrainableDropout") == 0) { auto& ratio_name = node.InputDefs()[1]->Name(); dropout_eval_feeds_.insert(ratio_name); ORT_ENFORCE(model_->MainGraph().GetProducerNode(ratio_name) == nullptr, - "Input: " + ratio_name + " should not have any producer node."); + "Input: " + ratio_name + " should not have any producer node."); defs.AddGraphInputs({ratio_name}); } // Found an opset-12 dropout node, replace initializer name. - else if(node.InputArgCount().size() > 2) { + else if (node.InputArgCount().size() > 2) { auto& mode_input = node.MutableInputDefs()[2]; const ONNX_NAMESPACE::TensorProto* mode_initializer = nullptr; if (!graph.GetInitializedTensor(training_mode_string_, mode_initializer)) { @@ -910,7 +917,7 @@ Status TrainingSession::SetEvalFeedNames() { } } } - + ORT_RETURN_IF_ERROR(GraphAugmenter::AugmentGraph(graph, defs)); return DoPostLoadProcessing(*model_); } @@ -1052,7 +1059,8 @@ std::unordered_set TrainingSession::GetTrainableModelInitializers( bool proceed = std::any_of(from->InputEdgesBegin(), from->InputEdgesEnd(), is_trainable_from_to_link); if (!proceed && session_logger_) { - VLOGS(*session_logger_, 1) << "Stopping training parameters discovery traversal from " << from->Name() << " to " << to->Name() << std::endl; + VLOGS(*session_logger_, 1) + << "Stopping training parameters discovery traversal from " << from->Name() << " to " << to->Name(); } return !proceed; diff --git a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc index 20192c5894..c652908b5b 100644 --- a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc @@ -22,7 +22,7 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) { std::string l2_transformer = "ConvActivationFusion"; std::vector custom_list = {l1_rule1, l1_transformer, l2_transformer}; - auto transformers = training::transformer_utils::GenerateTransformers(TransformerLevel::Level1, {}, custom_list); + auto transformers = training::transformer_utils::GenerateTransformers(TransformerLevel::Level1, {}, {}, custom_list); ASSERT_TRUE(transformers.size() == 1); auto l1_rule_transformer_name = optimizer_utils::GenerateRuleBasedTransformerName(TransformerLevel::Level1); @@ -34,7 +34,7 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) { } ASSERT_TRUE(rule_transformer && rule_transformer->RulesCount() == 1); - transformers = training::transformer_utils::GenerateTransformers(TransformerLevel::Level2, {}, custom_list); + transformers = training::transformer_utils::GenerateTransformers(TransformerLevel::Level2, {}, {}, custom_list); #ifndef DISABLE_CONTRIB_OPS ASSERT_TRUE(transformers.size() == 1); #else