Enable MatMul + Scale fusion (#4669)

Update TransposeMatMul to support scaling of the matrix product by a constant scalar value (analogous to the GEMM alpha parameter). Rename TransposeMatMul to TransposeScaleMatMul.
Fuse MatMul with surrounding Mul/Div with constant scalar into TransposeScaleMatMul.
This commit is contained in:
edgchen1 2020-08-04 16:27:22 -07:00 committed by GitHub
parent f9bd52f852
commit 9d7284fc3b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
29 changed files with 796 additions and 139 deletions

View file

@ -7,11 +7,13 @@
namespace onnxruntime {
// Graph transformer level
// refer to docs/ONNX_Runtime_Graph_Optimizations.md for details
enum class TransformerLevel : int {
Default = 0,
Level1,
Level2,
Level3,
Default = 0, // required transformers only
Level1, // basic optimizations
Level2, // extended optimizations
Level3, // layout optimizations
// The max level should always be same as the last level.
MaxLevel = Level3
};

View file

@ -20,7 +20,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeScaleMatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MaxpoolWithMask);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Pad);
@ -130,7 +130,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeScaleMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MaxpoolWithMask)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Unique)>,

View file

@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "transpose_matmul.h"
#include "contrib_ops/cpu/transpose_scale_matmul.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/util/math.h"
@ -9,20 +9,21 @@ namespace onnxruntime {
namespace contrib {
ONNX_OPERATOR_KERNEL_EX(
TransposeMatMul,
TransposeScaleMatMul,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
TransposeMatMul);
TransposeScaleMatMul);
TransposeMatMul::TransposeMatMul(const OpKernelInfo& info)
TransposeScaleMatMul::TransposeScaleMatMul(const OpKernelInfo& info)
: OpKernel{info} {
ORT_ENFORCE(info.GetAttr("transA", &trans_a_attr_).IsOK());
ORT_ENFORCE(info.GetAttr("transB", &trans_b_attr_).IsOK());
ORT_THROW_IF_ERROR(info.GetAttr("alpha", &alpha_attr_));
ORT_THROW_IF_ERROR(info.GetAttr("transA", &trans_a_attr_));
ORT_THROW_IF_ERROR(info.GetAttr("transB", &trans_b_attr_));
}
Status TransposeMatMul::Compute(OpKernelContext* context) const {
Status TransposeScaleMatMul::Compute(OpKernelContext* context) const {
concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
const Tensor* A = context->Input<Tensor>(0);
@ -47,7 +48,7 @@ Status TransposeMatMul::Compute(OpKernelContext* context) const {
trans_a ? CblasTrans : CblasNoTrans,
trans_b ? CblasTrans : CblasNoTrans,
helper.M(), helper.N(), helper.K(),
1.0f,
alpha_attr_,
A->Data<float>() + helper.LeftOffsets()[i],
B->Data<float>() + helper.RightOffsets()[i],
0.0f,

View file

@ -8,13 +8,14 @@
namespace onnxruntime {
namespace contrib {
class TransposeMatMul final : public OpKernel {
class TransposeScaleMatMul final : public OpKernel {
public:
TransposeMatMul(const OpKernelInfo& info);
TransposeScaleMatMul(const OpKernelInfo& info);
Status Compute(OpKernelContext* context) const override;
private:
float alpha_attr_;
int64_t trans_a_attr_, trans_b_attr_;
};

View file

@ -17,9 +17,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeScaleMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeScaleMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeScaleMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Rfft);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Rfft);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Rfft);
@ -81,9 +81,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeScaleMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeScaleMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeScaleMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Rfft)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Rfft)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Rfft)>,

View file

@ -9,7 +9,7 @@ namespace cuda {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
TransposeMatMul, \
TransposeScaleMatMul, \
kMSDomain, \
1, \
T, \

View file

@ -1774,17 +1774,22 @@ Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-
ONNX_NAMESPACE::matmulShapeInference(ctx, 0, 1);
});
static const char* TransposeMatMul_doc = R"DOC(
static const char* TransposeScaleMatMul_doc = R"DOC(
Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
)DOC";
ONNX_CONTRIB_OPERATOR_SCHEMA(TransposeMatMul)
ONNX_CONTRIB_OPERATOR_SCHEMA(TransposeScaleMatMul)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
.SetDoc("TransposeMatMul")
.SetDoc("TransposeScaleMatMul")
.Input(0, "A", "N-dimensional matrix A", "T")
.Input(1, "B", "N-dimensional matrix B", "T")
.Attr(
"alpha",
"Scalar multiplier for the product of the input tensors.",
AttributeProto::FLOAT,
1.0f)
.Attr(
"transA",
"Whether A should be transposed on the last two dimensions before doing multiplication",
@ -1800,7 +1805,7 @@ Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input and output types to float tensors.")
.SetDoc(TransposeMatMul_doc)
.SetDoc(TransposeScaleMatMul_doc)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
auto transAAttr = ctx.getAttribute("transA");

View file

@ -24,6 +24,7 @@
#include "core/optimizer/identity_elimination.h"
#include "core/optimizer/layer_norm_fusion.h"
#include "core/optimizer/matmul_add_fusion.h"
#include "core/optimizer/matmul_scale_fusion.h"
#include "core/optimizer/nchwc_transformer.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
@ -144,6 +145,8 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
transformers.emplace_back(onnxruntime::make_unique<SkipLayerNormFusion>(cpu_cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<FastGeluFusion>(cpu_cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<MatMulScaleFusion>(cpu_cuda_execution_providers));
#endif
} break;

View file

@ -0,0 +1,282 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/matmul_scale_fusion.h"
#include "onnx/defs/attr_proto_util.h"
#include "core/common/optional.h"
#include "core/framework/data_types_internal.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/graph_utils.h"
#include "core/graph/graph_viewer.h"
#include "core/optimizer/utils.h"
namespace onnxruntime {
namespace {
template <typename T>
struct ExtractScalarAsFloatDispatchTarget {
Status operator()(const ONNX_NAMESPACE::TensorProto& tensor_proto, float& scalar_float) {
T scalar;
ORT_RETURN_IF_ERROR(utils::UnpackTensor(tensor_proto, &scalar, 1));
scalar_float = static_cast<float>(scalar);
return Status::OK();
}
};
optional<float> GetScalarConstantInitializer(const Graph& graph, const NodeArg& node_arg) {
const auto* initializer = graph_utils::GetConstantInitializer(graph, node_arg.Name());
if (!initializer) {
// not a constant
return {};
}
const auto* shape = node_arg.Shape();
ORT_ENFORCE(
shape,
"Constant initializer NodeArg shape should not be null. NodeArg: ", node_arg.Name());
if (utils::GetTensorShapeFromTensorShapeProto(*shape).Size() != 1) {
// not a scalar
return {};
}
float scalar;
utils::MLTypeCallDispatcherRet<
Status, ExtractScalarAsFloatDispatchTarget,
uint32_t, uint64_t, int32_t, int64_t, MLFloat16, float, double>
dispatcher{initializer->data_type()};
ORT_THROW_IF_ERROR(dispatcher.Invoke(*initializer, scalar));
return {scalar};
}
// gets the scale value and its input index if node is a fusable scale (Mul or Div by scalar constant)
optional<std::pair<float, int>> GetScaleFromNode(
const Graph& graph, const Node& scale_node,
const std::unordered_set<std::string>& excluded_initializer_names) {
const auto is_excluded_initializer =
[&excluded_initializer_names](const NodeArg& node_arg) {
return excluded_initializer_names.find(node_arg.Name()) != excluded_initializer_names.end();
};
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Div", {7})) {
// (x / scale_reciprocal)
const auto div_inputs = scale_node.InputDefs();
ORT_ENFORCE(div_inputs.size() == 2);
const int scale_reciprocal_arg_index = 1;
const NodeArg& scale_reciprocal_node_arg = *div_inputs[scale_reciprocal_arg_index];
if (is_excluded_initializer(scale_reciprocal_node_arg)) return {};
const auto divisor = GetScalarConstantInitializer(graph, scale_reciprocal_node_arg);
if (!divisor.has_value()) return {};
return {std::make_pair(1.0f / divisor.value(), scale_reciprocal_arg_index)};
}
if (graph_utils::IsSupportedOptypeVersionAndDomain(scale_node, "Mul", {7})) {
// (x * scale) or (scale * x)
const auto mul_inputs = scale_node.InputDefs();
ORT_ENFORCE(mul_inputs.size() == 2);
for (int scale_arg_index = 0; scale_arg_index < 2; ++scale_arg_index) {
const NodeArg& scale_node_arg = *mul_inputs[scale_arg_index];
if (is_excluded_initializer(scale_node_arg)) continue;
const auto multiplier = GetScalarConstantInitializer(graph, scale_node_arg);
if (!multiplier.has_value()) continue;
return {std::make_pair(multiplier.value(), scale_arg_index)};
}
return {};
}
return {};
}
struct ScaleMergeInfo {
// the edge from the base node to the original node
Node::EdgeConstIterator node_to_merge_edge;
// the scale of the original node
float scale;
// the index of the input or output def on the original node
// this def is moved to the fused node
// for a leading scale (scale -> MatMul), it will be the unscaled input
// for a trailing scale (MatMul -> scale), it will be the scaled output
int node_to_merge_def_index;
// the index of the input or output def on the fused node
int fused_node_def_index;
};
std::vector<ScaleMergeInfo> GetInputNodeMerges(
Graph& graph, Node& node,
const std::unordered_set<std::string>& excluded_initializer_names) {
std::vector<ScaleMergeInfo> input_node_merges{};
for (auto input_edge = node.InputEdgesBegin(); input_edge != node.InputEdgesEnd(); ++input_edge) {
const Node& input_node = input_edge->GetNode();
if (input_node.GetExecutionProviderType() != node.GetExecutionProviderType()) continue;
const auto scale_and_index = GetScaleFromNode(graph, input_node, excluded_initializer_names);
if (!scale_and_index.has_value()) continue;
// assume scale nodes have 2 input defs, so to_scale_index == 1 - scale_index
ORT_ENFORCE(input_node.InputDefs().size() == 2 && scale_and_index.value().second < 2);
const int to_scale_index = 1 - scale_and_index.value().second;
input_node_merges.push_back(
{input_edge,
scale_and_index.value().first,
to_scale_index,
input_edge->GetDstArgIndex()});
}
return input_node_merges;
}
std::vector<ScaleMergeInfo> GetOutputNodeMerges(
Graph& graph, Node& node,
const std::unordered_set<std::string>& excluded_initializer_names) {
if (!optimizer_utils::CheckOutputEdges(graph, node, 1)) {
return {};
}
std::vector<ScaleMergeInfo> output_node_merges{};
for (auto output_edge = node.OutputEdgesBegin(); output_edge != node.OutputEdgesEnd(); ++output_edge) {
const Node& output_node = output_edge->GetNode();
if (output_node.GetExecutionProviderType() != node.GetExecutionProviderType()) continue;
const auto scale_and_index = GetScaleFromNode(graph, output_node, excluded_initializer_names);
if (!scale_and_index.has_value()) continue;
ORT_ENFORCE(output_node.OutputDefs().size() == 1);
const int scaled_index = 0;
output_node_merges.push_back(
{output_edge,
scale_and_index.value().first,
scaled_index,
output_edge->GetSrcArgIndex()});
}
return output_node_merges;
}
Status ProcessNode(
Graph& graph, Node& node, bool& modified,
const std::unordered_set<std::string>& excluded_initializer_names) {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeScaleMatMul", {1}, kMSDomain)) {
return Status::OK();
}
const std::vector<ScaleMergeInfo> input_node_merges = GetInputNodeMerges(
graph, node, excluded_initializer_names);
const std::vector<ScaleMergeInfo> output_node_merges = GetOutputNodeMerges(
graph, node, excluded_initializer_names);
if (input_node_merges.empty() && output_node_merges.empty()) {
return Status::OK();
}
NodeAttributes fused_node_attrs =
node.OpType() == "TransposeScaleMatMul" ? node.GetAttributes() : NodeAttributes{};
{
ONNX_NAMESPACE::AttributeProto& alpha_attr = fused_node_attrs["alpha"];
float total_scale = utils::HasFloat(alpha_attr) ? alpha_attr.f() : 1.0f;
auto accumulate_scale = [&total_scale](const ScaleMergeInfo& fusion) {
total_scale *= fusion.scale;
};
std::for_each(input_node_merges.begin(), input_node_merges.end(), accumulate_scale);
std::for_each(output_node_merges.begin(), output_node_merges.end(), accumulate_scale);
alpha_attr = ONNX_NAMESPACE::MakeAttribute("alpha", total_scale);
}
auto get_mutable_node_to_merge = [&graph](const ScaleMergeInfo& merge) -> Node& {
return *graph.GetNode(merge.node_to_merge_edge->GetNode().Index());
};
std::vector<NodeArg*> fused_node_inputs = node.MutableInputDefs();
for (const auto& input_node_merge : input_node_merges) {
Node& input_node = get_mutable_node_to_merge(input_node_merge);
fused_node_inputs[input_node_merge.fused_node_def_index] =
input_node.MutableInputDefs()[input_node_merge.node_to_merge_def_index];
}
std::vector<NodeArg*> fused_node_outputs = node.MutableOutputDefs();
for (const auto& output_node_merge : output_node_merges) {
Node& output_node = get_mutable_node_to_merge(output_node_merge);
fused_node_outputs[output_node_merge.fused_node_def_index] =
output_node.MutableOutputDefs()[output_node_merge.node_to_merge_def_index];
}
Node& matmul_scale_node = graph.AddNode(
graph.GenerateNodeName(node.Name() + "_FusedMatMulAndScale"),
"TransposeScaleMatMul",
"Fused MatMul and Scale",
fused_node_inputs,
fused_node_outputs,
&fused_node_attrs,
kMSDomain);
matmul_scale_node.SetExecutionProviderType(node.GetExecutionProviderType());
{
std::vector<std::reference_wrapper<Node>> nodes_to_remove{node};
for (const auto& input_node_merge : input_node_merges) {
// remove merged input node's output edge
auto input_node_edge = input_node_merge.node_to_merge_edge;
Node& input_node = get_mutable_node_to_merge(input_node_merge);
graph.RemoveEdge(
input_node.Index(), node.Index(),
input_node_edge->GetSrcArgIndex(), input_node_edge->GetDstArgIndex());
// only remove merged input node if it has no more outputs
if (!optimizer_utils::CheckOutputEdges(graph, input_node, 0)) continue;
nodes_to_remove.push_back(input_node);
}
for (const auto& output_node_merge : output_node_merges) {
nodes_to_remove.push_back(get_mutable_node_to_merge(output_node_merge));
}
for (Node& node_to_remove : nodes_to_remove) {
graph_utils::RemoveNodeOutputEdges(graph, node_to_remove);
graph.RemoveNode(node_to_remove.Index());
}
}
modified = true;
return Status::OK();
}
} // namespace
Status MatMulScaleFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger)
const {
GraphViewer graph_viewer{graph};
const auto node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (const auto node_index : node_indices) {
auto* node = graph.GetNode(node_index);
if (!node) continue;
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger));
ORT_RETURN_IF_ERROR(ProcessNode(graph, *node, modified, excluded_initializer_names_));
}
return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/graph_transformer.h"
namespace onnxruntime {
/**
* Fuses MatMul with surrounding scales (multiplies or divides) by a constant
* scalar into TransposeScaleMatMul.
*
* For example, given matrices A and B and constant scalars t, u, and v:
* Mul(v, MatMul(Mul(t, A), Mul(u, B))
* -> TransposeScaleMatMul(A, B, alpha=t*u*v)
*/
class MatMulScaleFusion : public GraphTransformer {
public:
/**
* Constructor.
* @param compatible_execution_providers The compatible execution providers.
* @param excluded_initializer_names Fusion will be skipped on scales by any
* of the named initializers.
*/
MatMulScaleFusion(
const std::unordered_set<std::string>& compatible_execution_providers = {},
const std::unordered_set<std::string>& excluded_initializer_names = {})
: GraphTransformer{"MatMulScaleFusion", compatible_execution_providers},
excluded_initializer_names_{excluded_initializer_names} {
}
private:
Status ApplyImpl(
Graph& graph, bool& modified,
int graph_level, const logging::Logger& logger) const override;
const std::unordered_set<std::string> excluded_initializer_names_;
};
} // namespace onnxruntime

View file

@ -71,7 +71,7 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
if ((!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeMatMul", {1}, kMSDomain)) ||
!graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeScaleMatMul", {1}, kMSDomain)) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
continue;
}
@ -104,16 +104,16 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
const std::vector<NodeArg*> output_defs{node.MutableOutputDefs()[0]};
Node& matmul_node = graph.AddNode(graph.GenerateNodeName("MatMul_With_Transpose"),
"TransposeMatMul",
"TransposeScaleMatMul",
"fused MatMul and Transpose ",
input_defs,
output_defs, {}, kMSDomain);
bool transpose_left = (left != nullptr);
if (node.OpType() == "TransposeMatMul") {
if (node.OpType() == "TransposeScaleMatMul") {
transpose_left ^= static_cast<bool>(node.GetAttributes().at("transA").i());
}
bool transpose_right = (right != nullptr);
if (node.OpType() == "TransposeMatMul") {
if (node.OpType() == "TransposeScaleMatMul") {
transpose_right ^= static_cast<bool>(node.GetAttributes().at("transB").i());
}
matmul_node.AddAttribute("transA", static_cast<int64_t>(transpose_left));

View file

@ -96,8 +96,8 @@ Status MatMul<T>::ComputeInternal(OpKernelContext* ctx) const {
if (Y->Shape().Size() == 0)
return Status::OK();
CudaT one = ToCudaType<T>::FromFloat(1.0f);
CudaT zero = ToCudaType<T>::FromFloat(0.0f);
const CudaT alpha = ToCudaType<T>::FromFloat(alpha_);
const CudaT zero = ToCudaType<T>::FromFloat(0.0f);
cublasOperation_t transA = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t transB = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
@ -114,7 +114,7 @@ Status MatMul<T>::ComputeInternal(OpKernelContext* ctx) const {
static_cast<int>(helper.N()),
static_cast<int>(helper.M()),
static_cast<int>(helper.K()),
&one,
&alpha,
reinterpret_cast<const CudaT*>(right_X->template Data<T>()),
ldb,
reinterpret_cast<const CudaT*>(left_X->template Data<T>()),
@ -132,7 +132,7 @@ Status MatMul<T>::ComputeInternal(OpKernelContext* ctx) const {
static_cast<int>(helper.N()),
static_cast<int>(helper.M()),
static_cast<int>(helper.K()),
&one,
&alpha,
reinterpret_cast<const CudaT*>(right_X->template Data<T>()),
ldb,
stride_B,
@ -167,7 +167,7 @@ Status MatMul<T>::ComputeInternal(OpKernelContext* ctx) const {
static_cast<int>(helper.N()),
static_cast<int>(helper.M()),
static_cast<int>(helper.K()),
&one,
&alpha,
right_arrays.GpuPtr(),
ldb,
left_arrays.GpuPtr(),

View file

@ -13,16 +13,18 @@ class MatMul final : public CudaKernel {
public:
MatMul(const OpKernelInfo& info)
: CudaKernel(info) {
trans_A_ = info.GetAttrOrDefault<int64_t>("transA", 0);
trans_B_ = info.GetAttrOrDefault<int64_t>("transB", 0);
: CudaKernel(info),
alpha_{info.GetAttrOrDefault<float>("alpha", 1.0f)},
trans_A_{info.GetAttrOrDefault<int64_t>("transA", 0) != 0},
trans_B_{info.GetAttrOrDefault<int64_t>("transB", 0) != 0} {
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
bool trans_A_;
bool trans_B_;
const float alpha_;
const bool trans_A_;
const bool trans_B_;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -131,10 +131,10 @@ void ProcessInputs(const std::vector<int64_t>& input_dims, const std::vector<T>&
}
template <typename T>
void RunTransposeMatMulTest(int32_t opset_version = 7, bool transa = false, bool transb = false) {
void RunTransposeScaleMatMulTest(int32_t opset_version = 7, bool transa = false, bool transb = false, float alpha = 1.0f) {
std::vector<T> common_input_vals{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
for (auto t : GenerateSimpleTestCases<T>()) {
OpTester test("TransposeMatMul", opset_version, onnxruntime::kMSDomain);
OpTester test("TransposeScaleMatMul", opset_version, onnxruntime::kMSDomain);
std::vector<int64_t> input0_dims(t.input0_dims);
std::vector<T> input0_vals;
@ -150,6 +150,13 @@ void RunTransposeMatMulTest(int32_t opset_version = 7, bool transa = false, bool
test.AddAttribute("transA", (int64_t)transa);
test.AddAttribute("transB", (int64_t)transb);
test.AddAttribute("alpha", alpha);
if (alpha != 1.0f) {
std::transform(
t.expected_vals.begin(), t.expected_vals.end(), t.expected_vals.begin(),
[alpha](const T& val) -> T { return alpha * val; });
}
test.AddOutput<T>("Y", t.expected_dims, t.expected_vals);
@ -159,25 +166,31 @@ void RunTransposeMatMulTest(int32_t opset_version = 7, bool transa = false, bool
}
TEST(TransposeMatMulOpTest, FloatTypeNoTranspose) {
RunTransposeMatMulTest<float>(1);
RunTransposeScaleMatMulTest<float>(1);
}
#ifdef USE_CUDA // double support only implemented in CUDA kernel
TEST(TransposeMatMulOpTest, DoubleTypeNoTranspose) {
RunTransposeMatMulTest<double>(1);
RunTransposeScaleMatMulTest<double>(1);
}
#endif
TEST(TransposeMatMulOpTest, FloatTypeTransposeA) {
RunTransposeMatMulTest<float>(1, true, false);
RunTransposeScaleMatMulTest<float>(1, true, false);
}
TEST(TransposeMatMulOpTest, FloatTypeTransposeB) {
RunTransposeMatMulTest<float>(1, false, true);
RunTransposeScaleMatMulTest<float>(1, false, true);
}
TEST(TransposeMatMulOpTest, FloatTypeTransposeAB) {
RunTransposeMatMulTest<float>(1, true, true);
RunTransposeScaleMatMulTest<float>(1, true, true);
}
TEST(TransposeMatMulOpTest, FloatTypeScale) {
RunTransposeScaleMatMulTest<float>(1, false, false, 0.5f);
RunTransposeScaleMatMulTest<float>(1, true, false, 2.0f);
RunTransposeScaleMatMulTest<float>(1, true, true, 4.0f);
}
} // namespace transpose_matmul

View file

@ -37,6 +37,7 @@
#include "core/optimizer/initializer.h"
#include "core/optimizer/layer_norm_fusion.h"
#include "core/optimizer/matmul_add_fusion.h"
#include "core/optimizer/matmul_scale_fusion.h"
#include "core/optimizer/matmul_transpose_fusion.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
@ -702,7 +703,7 @@ TEST_F(GraphTransformationTests, TransposeMatmulFusion) {
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Transpose"] == 0);
ASSERT_TRUE(op_to_count["MatMul"] == 0);
ASSERT_TRUE(op_to_count["TransposeMatMul"] == 1);
ASSERT_TRUE(op_to_count["TransposeScaleMatMul"] == 1);
}
TEST_F(GraphTransformationTests, TransposeMatmulFusionOnTwoTranspose) {
@ -719,10 +720,10 @@ TEST_F(GraphTransformationTests, TransposeMatmulFusionOnTwoTranspose) {
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Transpose"] == 0);
ASSERT_TRUE(op_to_count["MatMul"] == 0);
ASSERT_TRUE(op_to_count["TransposeMatMul"] == 1);
ASSERT_TRUE(op_to_count["TransposeScaleMatMul"] == 1);
auto& node = *graph.Nodes().begin();
ASSERT_TRUE(node.OpType() == "TransposeMatMul");
ASSERT_TRUE(node.OpType() == "TransposeScaleMatMul");
ASSERT_TRUE(static_cast<bool>(node.GetAttributes().at("transA").i()));
ASSERT_TRUE(static_cast<bool>(node.GetAttributes().at("transB").i()));
}
@ -741,10 +742,10 @@ TEST_F(GraphTransformationTests, TransposeMatmulFusionOnThreeTranspose) {
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Transpose"] == 0);
ASSERT_TRUE(op_to_count["MatMul"] == 0);
ASSERT_TRUE(op_to_count["TransposeMatMul"] == 1);
ASSERT_TRUE(op_to_count["TransposeScaleMatMul"] == 1);
auto& node = *graph.Nodes().begin();
ASSERT_TRUE(node.OpType() == "TransposeMatMul");
ASSERT_TRUE(node.OpType() == "TransposeScaleMatMul");
ASSERT_FALSE(static_cast<bool>(node.GetAttributes().at("transA").i()));
ASSERT_TRUE(static_cast<bool>(node.GetAttributes().at("transB").i()));
}
@ -763,7 +764,7 @@ TEST_F(GraphTransformationTests, TransposeMatmulNoFusionOnInvalidPerm) {
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Transpose"] == 1);
ASSERT_TRUE(op_to_count["MatMul"] == 1);
ASSERT_TRUE(op_to_count["TransposeMatMul"] == 0);
ASSERT_TRUE(op_to_count["TransposeScaleMatMul"] == 0);
}
TEST_F(GraphTransformationTests, Gemm_LeakyRelu_Fusion) {
@ -2824,5 +2825,111 @@ TEST_F(GraphTransformationTests, ComputationReductionTransformer_GatherND_E2E) {
}
#endif
#ifndef DISABLE_CONTRIB_OPS
template <typename GraphTransformationCheckFn>
static void TestMatMulScaleFusion(
const PathString& model_path, const Logger& logger,
GraphTransformationCheckFn graph_transformation_check,
const std::unordered_set<std::string>& excluded_initializer_names = {}) {
SCOPED_TRACE(ORT_TSTR("model path: ") + model_path);
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(model_path, model, nullptr, logger));
Graph& graph = model->MainGraph();
auto original_op_counts = CountOpsInGraph(graph);
onnxruntime::GraphTransformerManager graph_transformer_manager{5};
ASSERT_STATUS_OK(graph_transformer_manager.Register(
make_unique<MatMulScaleFusion>(std::unordered_set<std::string>{}, excluded_initializer_names),
TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformer_manager.ApplyTransformers(graph, TransformerLevel::Level2, logger));
auto transformed_op_counts = CountOpsInGraph(graph);
graph_transformation_check(graph, original_op_counts, transformed_op_counts);
}
TEST_F(GraphTransformationTests, MatMulScaleFusionFusableModels) {
const std::vector<PathString> one_fusion_model_paths{
MODEL_FOLDER "fusion/matmul_scale_in0.onnx",
MODEL_FOLDER "fusion/matmul_scale_in0_in1.onnx",
MODEL_FOLDER "fusion/matmul_scale_in0_in1_out.onnx",
MODEL_FOLDER "fusion/matmul_scale_transposescalematmul_in0_in1_out.onnx",
};
for (const auto& path : one_fusion_model_paths) {
TestMatMulScaleFusion(
path, *logger_,
[](const Graph& graph,
std::map<std::string, int> original_op_counts,
std::map<std::string, int> transformed_op_counts) {
EXPECT_EQ(transformed_op_counts["Mul"], 0);
EXPECT_EQ(transformed_op_counts["Div"], 0);
EXPECT_EQ(transformed_op_counts["MatMul"], 0);
EXPECT_EQ(transformed_op_counts["TransposeScaleMatMul"], 1);
// check combined scale, individual scales should all have the same value
const float scale_value = 3.0f;
const int num_scales =
original_op_counts["Mul"] + original_op_counts["Div"] + original_op_counts["TransposeScaleMatMul"];
auto fused_node = std::find_if(
graph.Nodes().cbegin(), graph.Nodes().cend(),
[](const Node& node) { return node.OpType() == "TransposeScaleMatMul"; });
ASSERT_NE(fused_node, graph.Nodes().cend());
auto alpha_attr = fused_node->GetAttributes().find("alpha");
ASSERT_NE(alpha_attr, fused_node->GetAttributes().end());
EXPECT_EQ(alpha_attr->second.f(), pow(scale_value, num_scales));
});
}
}
TEST_F(GraphTransformationTests, MatMulScaleFusionUnfusableModels) {
const std::vector<PathString> unfusable_model_paths{
MODEL_FOLDER "fusion/matmul_scale_unfusable_div_not_scale.onnx",
MODEL_FOLDER "fusion/matmul_scale_unfusable_scale_not_scalar.onnx",
MODEL_FOLDER "fusion/matmul_scale_unfusable_scale_not_constant.onnx",
};
for (const auto& path : unfusable_model_paths) {
TestMatMulScaleFusion(
path, *logger_,
[](const Graph&,
const std::map<std::string, int>& original_op_counts,
const std::map<std::string, int>& transformed_op_counts) {
EXPECT_EQ(original_op_counts, transformed_op_counts);
});
}
}
TEST_F(GraphTransformationTests, MatMulScaleFusionReusedInputScale) {
TestMatMulScaleFusion(
MODEL_FOLDER "fusion/matmul_scale_reused_input_scale.onnx", *logger_,
[](const Graph&,
const std::map<std::string, int>&,
std::map<std::string, int> transformed_op_counts) {
EXPECT_EQ(transformed_op_counts["Mul"], 0);
EXPECT_EQ(transformed_op_counts["Div"], 0);
EXPECT_EQ(transformed_op_counts["MatMul"], 0);
EXPECT_EQ(transformed_op_counts["TransposeScaleMatMul"], 2);
});
}
TEST_F(GraphTransformationTests, MatMulScaleFusionExcludedInitializerName) {
TestMatMulScaleFusion(
MODEL_FOLDER "fusion/matmul_scale_in0.onnx", *logger_,
[](const Graph&,
const std::map<std::string, int>& original_op_counts,
const std::map<std::string, int>& transformed_op_counts) {
EXPECT_EQ(original_op_counts, transformed_op_counts);
},
{"scale"});
}
#endif
} // namespace test
} // namespace onnxruntime

View file

@ -0,0 +1,186 @@
import onnx
from onnx import helper
from onnx import TensorProto
from onnx import OperatorSetIdProto
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 12
# The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
onnxdomain.domain = ""
msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = "com.microsoft"
opsets = [onnxdomain, msdomain]
scale_value = 3.0
def save(model_path, nodes, inputs, outputs, initializers):
graph = helper.make_graph(
nodes,
"MatMulScaleTest",
inputs, outputs, initializers)
model = helper.make_model(
graph, opset_imports=opsets, producer_name="onnxruntime-test")
onnx.save(model, model_path)
def gen(model_path,
use_transpose_matmul,
scale_input_0, scale_input_1, scale_output):
matmul_op = "TransposeScaleMatMul" if use_transpose_matmul else "MatMul"
matmul_domain = "com.microsoft" if use_transpose_matmul else ""
matmul_attrs = {"alpha": scale_value} if use_transpose_matmul else {}
nodes = []
if scale_input_0:
nodes.append(helper.make_node(
"Mul", ["input_0", "scale"], ["scaled_input_0"], "scale input_0"))
if scale_input_1:
nodes.append(helper.make_node(
"Div", ["input_1", "scale_reciprocal"], ["scaled_input_1"], "scale input_1"))
nodes.append(helper.make_node(
matmul_op,
[
"scaled_input_0" if scale_input_0 else "input_0",
"scaled_input_1" if scale_input_1 else "input_1"
],
[
"unscaled_output" if scale_output else "output"
],
matmul_op,
"",
matmul_domain,
**matmul_attrs))
if scale_output:
nodes.append(helper.make_node(
"Mul", ["scale", "unscaled_output"], ["output"], "scale output"))
initializers = [
helper.make_tensor("scale", TensorProto.FLOAT, [], [scale_value]),
helper.make_tensor("scale_reciprocal",
TensorProto.FLOAT, [], [1/scale_value])
]
inputs = [
helper.make_tensor_value_info(
"input_0", TensorProto.FLOAT, [2, 'M', 'K']),
helper.make_tensor_value_info(
"input_1", TensorProto.FLOAT, [2, 'K', 'N'])
]
outputs = [
helper.make_tensor_value_info(
"output", TensorProto.FLOAT, [2, 'M', 'N'])
]
save(model_path, nodes, inputs, outputs, initializers)
gen("matmul_scale_in0.onnx", False, True, False, False)
gen("matmul_scale_in0_in1.onnx", False, True, True, False)
gen("matmul_scale_in0_in1_out.onnx", False, True, True, True)
gen("matmul_scale_transposescalematmul_in0_in1_out.onnx", True, True, True, True)
UNFUSABLE_DIV_NOT_SCALE = 0
UNFUSABLE_SCALE_NOT_SCALAR = 1
UNFUSABLE_SCALE_NOT_CONSTANT = 2
def gen_unfusable(model_path, unfusable_type):
matmul_op = "MatMul"
if unfusable_type == UNFUSABLE_DIV_NOT_SCALE:
scale_node = helper.make_node(
"Div", ["scale", "input_0"], ["scaled_input_0"], "scale input_0")
elif unfusable_type == UNFUSABLE_SCALE_NOT_SCALAR:
scale_node = helper.make_node(
"Mul", ["scale_non_scalar", "input_0"], ["scaled_input_0"], "scale input_0")
elif unfusable_type == UNFUSABLE_SCALE_NOT_CONSTANT:
scale_node = helper.make_node(
"Mul", ["input_0", "input_0"], ["scaled_input_0"], "scale input_0")
else:
raise ValueError("Invalid unfusable_type: {}".format(unfusable_type))
nodes = [
scale_node,
helper.make_node(
matmul_op, ["scaled_input_0", "input_1"], ["output"], matmul_op)
]
initializers = [
helper.make_tensor("scale", TensorProto.FLOAT, [], [scale_value]),
helper.make_tensor("scale_non_scalar", TensorProto.FLOAT,
[2, 1, 1], [scale_value, scale_value])
]
inputs = [
helper.make_tensor_value_info(
"input_0", TensorProto.FLOAT, [2, 'M', 'K']),
helper.make_tensor_value_info(
"input_1", TensorProto.FLOAT, [2, 'K', 'N'])
]
outputs = [
helper.make_tensor_value_info(
"output", TensorProto.FLOAT, [2, 'M', 'N'])
]
save(model_path, nodes, inputs, outputs, initializers)
gen_unfusable("matmul_scale_unfusable_div_not_scale.onnx",
UNFUSABLE_DIV_NOT_SCALE)
gen_unfusable("matmul_scale_unfusable_scale_not_scalar.onnx",
UNFUSABLE_SCALE_NOT_SCALAR)
gen_unfusable("matmul_scale_unfusable_scale_not_constant.onnx",
UNFUSABLE_SCALE_NOT_CONSTANT)
def gen_reused_input_scale(model_path):
matmul_op = "MatMul"
nodes = [
helper.make_node(
"Mul", ["input_0", "scale"], ["scaled_input_0"],
"scale input_0"),
helper.make_node(
matmul_op, ["scaled_input_0", "input_1"], ["output_0"],
"MatMul input_0 and input_1"),
helper.make_node(
matmul_op, ["scaled_input_0", "input_2"], ["output_1"],
"MatMul input_0 and input_2")
]
initializers = [
helper.make_tensor("scale", TensorProto.FLOAT, [], [scale_value])
]
inputs = [
helper.make_tensor_value_info(
"input_0", TensorProto.FLOAT, [2, 'M', 'K']),
helper.make_tensor_value_info(
"input_1", TensorProto.FLOAT, [2, 'K', 'N']),
helper.make_tensor_value_info(
"input_2", TensorProto.FLOAT, [2, 'K', 'N'])
]
outputs = [
helper.make_tensor_value_info(
"output_0", TensorProto.FLOAT, [2, 'M', 'N']),
helper.make_tensor_value_info(
"output_1", TensorProto.FLOAT, [2, 'M', 'N'])
]
save(model_path, nodes, inputs, outputs, initializers)
gen_reused_input_scale("matmul_scale_reused_input_scale.onnx")

Binary file not shown.

View file

@ -254,7 +254,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) {
if (IsGradientRequiredForSrcNodeInput(0)) {
ArgDef pre_reduce_grad_0 = IA("PreReduceGrad0");
result.push_back(
NodeDef(OpDef{"TransposeMatMul", kMSDomain, 1},
NodeDef(OpDef{"TransposeScaleMatMul", kMSDomain, 1},
{GO(0), B},
{pre_reduce_grad_0},
{{"transB", MakeAttribute("transB", int64_t(1))}}));
@ -267,7 +267,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) {
if (IsGradientRequiredForSrcNodeInput(1)) {
ArgDef pre_reduce_grad_1 = IA("PreReduceGrad1");
result.push_back(
NodeDef(OpDef{"TransposeMatMul", kMSDomain, 1},
NodeDef(OpDef{"TransposeScaleMatMul", kMSDomain, 1},
{A, GO(0)},
{pre_reduce_grad_1},
{{"transA", MakeAttribute("transA", int64_t(1))}}));

View file

@ -3,51 +3,53 @@
#include "orttraining/core/optimizer/graph_transformer_utils.h"
#include "core/mlas/inc/mlas.h"
#include "core/optimizer/bias_gelu_fusion.h"
#include "core/optimizer/cast_elimination.h"
#include "core/optimizer/computation_reduction.h"
#include "core/optimizer/constant_folding.h"
#include "core/optimizer/conv_activation_fusion.h"
#include "core/optimizer/conv_add_fusion.h"
#include "core/optimizer/conv_bn_fusion.h"
#include "core/optimizer/conv_mul_fusion.h"
#include "core/optimizer/dropout_elimination.h"
#include "core/optimizer/embed_layer_norm_fusion.h"
#include "core/optimizer/expand_elimination.h"
#include "core/optimizer/fast_gelu_fusion.h"
#include "core/optimizer/free_dim_override_transformer.h"
#include "core/optimizer/gelu_approximation.h"
#include "core/optimizer/gelu_fusion.h"
#include "core/optimizer/gemm_activation_fusion.h"
#include "core/optimizer/graph_transformer_utils.h"
#include "core/optimizer/identity_elimination.h"
#include "core/optimizer/layer_norm_fusion.h"
#include "core/optimizer/matmul_add_fusion.h"
#include "core/optimizer/matmul_scale_fusion.h"
#include "core/optimizer/matmul_transpose_fusion.h"
#include "core/optimizer/nchwc_transformer.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/shape_to_initializer.h"
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/unsqueeze_elimination.h"
#include "core/session/inference_session.h"
#include "orttraining/core/framework/distributed_run_context.h"
#include "orttraining/core/optimizer/bias_dropout_fusion.h"
#include "orttraining/core/optimizer/insert_output_rewriter.h"
#include "orttraining/core/optimizer/megatron_transformer.h"
#include "orttraining/core/optimizer/bias_dropout_fusion.h"
#include "orttraining/core/optimizer/nonzero_shape_setter.h"
#include "core/optimizer/identity_elimination.h"
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/conv_mul_fusion.h"
#include "core/optimizer/conv_bn_fusion.h"
#include "core/optimizer/conv_add_fusion.h"
#include "core/optimizer/constant_folding.h"
#include "core/optimizer/unsqueeze_elimination.h"
#include "core/optimizer/expand_elimination.h"
#include "core/optimizer/cast_elimination.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/conv_activation_fusion.h"
#include "core/optimizer/gemm_activation_fusion.h"
#include "core/optimizer/matmul_add_fusion.h"
#include "core/optimizer/dropout_elimination.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/shape_to_initializer.h"
#include "core/optimizer/nchwc_transformer.h"
#include "core/optimizer/free_dim_override_transformer.h"
#include "core/optimizer/gelu_fusion.h"
#include "core/optimizer/layer_norm_fusion.h"
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/optimizer/embed_layer_norm_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/matmul_transpose_fusion.h"
#include "core/optimizer/bias_gelu_fusion.h"
#include "core/optimizer/fast_gelu_fusion.h"
#include "core/optimizer/gelu_approximation.h"
#include "core/optimizer/graph_transformer_utils.h"
#include "core/optimizer/computation_reduction.h"
#include "core/mlas/inc/mlas.h"
#include "core/session/inference_session.h"
namespace onnxruntime {
namespace training {
namespace transformer_utils {
std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(TransformerLevel level,
const std::unordered_set<std::string>& weights_to_train,
bool enable_gelu_approximation,
const std::vector<std::string>& transformers_and_rules_to_enable) {
std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
TransformerLevel level,
const std::unordered_set<std::string>& weights_to_train,
bool enable_gelu_approximation,
const std::vector<std::string>& transformers_and_rules_to_enable) {
std::vector<std::unique_ptr<GraphTransformer>> transformers;
std::unique_ptr<RuleBasedGraphTransformer> rule_transformer = nullptr;
@ -130,9 +132,11 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(T
return filtered_list;
}
std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerLevel level,
gsl::span<const FreeDimensionOverride> free_dimension_overrides,
const std::vector<std::string>& transformers_and_rules_to_enable) {
std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
TransformerLevel level,
const std::unordered_set<std::string>& weights_to_train,
gsl::span<const FreeDimensionOverride> free_dimension_overrides,
const std::vector<std::string>& transformers_and_rules_to_enable) {
std::vector<std::unique_ptr<GraphTransformer>> transformers;
std::unique_ptr<RuleBasedGraphTransformer> rule_transformer = nullptr;
switch (level) {
@ -146,6 +150,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
transformers.emplace_back(onnxruntime::make_unique<FreeDimensionOverrideTransformer>(free_dimension_overrides));
transformers.emplace_back(onnxruntime::make_unique<MatmulTransposeFusion>(l1_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<BiasDropoutFusion>(l1_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<MatMulScaleFusion>(l1_execution_providers, weights_to_train));
rule_transformer = optimizer_utils::GenerateRuleBasedGraphTransformer(level, transformers_and_rules_to_enable, l1_execution_providers);
} break;

View file

@ -14,17 +14,20 @@ namespace training {
namespace transformer_utils {
/** Generates all pre-training transformers for this level. */
std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(TransformerLevel level,
const std::unordered_set<std::string>& weights_to_train,
bool enable_gelu_approximation,
const std::vector<std::string>& rules_and_transformers_to_enable = {});
std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
TransformerLevel level,
const std::unordered_set<std::string>& weights_to_train,
bool enable_gelu_approximation,
const std::vector<std::string>& rules_and_transformers_to_enable = {});
/** Generates all predefined (both rule-based and non-rule-based) transformers for this level.
If transformers_and_rules_to_enable is not empty, it returns the intersection between the predefined transformers/rules
and the transformers_and_rules_to_enable. */
std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerLevel level,
gsl::span<const FreeDimensionOverride> free_dimension_overrides,
const std::vector<std::string>& rules_and_transformers_to_enable = {});
std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
TransformerLevel level,
const std::unordered_set<std::string>& weights_to_train,
gsl::span<const FreeDimensionOverride> free_dimension_overrides,
const std::vector<std::string>& rules_and_transformers_to_enable = {});
} // namespace transformer_utils
} // namespace training

View file

@ -109,16 +109,16 @@ bool IsRootNode(const TrainingSession::TrainingConfiguration& config) {
}
} // namespace
void TrainingSession::FilterUnusedWeights(const std::unordered_set<std::string>& weight_names_to_train,
std::unordered_set<std::string>& filtered_weight_names_to_train) {
std::unordered_set<std::string>& filtered_weight_names_to_train) {
filtered_weight_names_to_train.clear();
for (const auto& name: weight_names_to_train) {
for (const auto& name : weight_names_to_train) {
auto nodes = model_->MainGraph().GetConsumerNodes(name);
if (!nodes.empty())
filtered_weight_names_to_train.insert(name);
else
LOGS(*session_logger_, WARNING) << "Couldn't find any consumer node for weight " << name << ", exclude it from training.";
LOGS(*session_logger_, WARNING)
<< "Couldn't find any consumer node for weight " << name << ", exclude it from training.";
}
}
@ -138,7 +138,8 @@ Status TrainingSession::ConfigureForTraining(
TrainingConfigurationResult config_result{};
ORT_ENFORCE(config.distributed_config.pipeline_parallel_size > 0,
"This parameter should be 1 if there is no pipelie parallelism. Otherwise, it's the number of pipeline stages.");
"This parameter should be 1 if there is no pipeline parallelism. "
"Otherwise, it's the number of pipeline stages.");
DistributedRunContext::CreateInstance({config.distributed_config.world_rank,
config.distributed_config.world_size,
@ -272,8 +273,9 @@ Status TrainingSession::ConfigureForTraining(
}
pipeline_result.fetch_names.push_back(name);
}
pipeline_result.pipeline_stage_id = config.distributed_config.world_rank /
(config.distributed_config.data_parallel_size * config.distributed_config.horizontal_parallel_size);
pipeline_result.pipeline_stage_id =
config.distributed_config.world_rank /
(config.distributed_config.data_parallel_size * config.distributed_config.horizontal_parallel_size);
config_result.pipeline_config_result = pipeline_result;
}
@ -478,21 +480,24 @@ static Status AddGradientAccumulationNodes(Graph& graph,
gradient_accumulation_buffers.resize(gradient_argdefs.size());
std::vector<std::string> grad_acc_outputs;
for (size_t i = 0; i < gradient_argdefs.size(); ++i) {
grad_acc_outputs.push_back(BuildGradientAccumulationNode(
nodearg_name_generator, gradient_argdefs[i], gradient_accumulation_buffers[i], graph_defs, false)
.name);
grad_acc_outputs.push_back(
BuildGradientAccumulationNode(
nodearg_name_generator, gradient_argdefs[i], gradient_accumulation_buffers[i], graph_defs, false)
.name);
}
return GraphAugmenter::AugmentGraph(graph, graph_defs);
}
Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set<std::string>& weights_to_train, bool enable_gelu_approximation) {
Status TrainingSession::ApplyTransformationsToMainGraph(
const std::unordered_set<std::string>& weights_to_train, bool enable_gelu_approximation) {
GraphTransformerManager graph_transformation_mgr{1};
AddPreTrainingTransformers(graph_transformation_mgr, weights_to_train, enable_gelu_approximation);
// apply transformers
Graph& graph = model_->MainGraph();
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i), *session_logger_));
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(
graph, static_cast<TransformerLevel>(i), *session_logger_));
}
return common::Status::OK();
}
@ -530,7 +535,8 @@ void TrainingSession::AddPredefinedTransformers(GraphTransformerManager& transfo
const std::vector<std::string>& custom_list) {
auto add_transformers = [&](TransformerLevel level) {
// Generate and register transformers for level
auto transformers_to_register = transformer_utils::GenerateTransformers(level, GetSessionOptions().free_dimension_overrides, custom_list);
auto transformers_to_register = transformer_utils::GenerateTransformers(
level, weights_to_train_, GetSessionOptions().free_dimension_overrides, custom_list);
for (auto& entry : transformers_to_register) {
transformer_manager.Register(std::move(entry), level);
}
@ -636,10 +642,12 @@ Status TrainingSession::ConfigureLossFunction(
return DoPostLoadProcessing(*model_);
}
Status TrainingSession::EnableMixedPrecision(const std::unordered_set<std::string>& weights_to_train,
bool use_fp16_initializer,
std::unordered_map<std::string, NodeArg*>& fp32_weight_name_to_fp16_node_arg) {
ORT_RETURN_IF_ERROR(TransformGraphForMixedPrecision(model_->MainGraph(), weights_to_train, use_fp16_initializer, fp32_weight_name_to_fp16_node_arg));
Status TrainingSession::EnableMixedPrecision(
const std::unordered_set<std::string>& weights_to_train,
bool use_fp16_initializer,
std::unordered_map<std::string, NodeArg*>& fp32_weight_name_to_fp16_node_arg) {
ORT_RETURN_IF_ERROR(TransformGraphForMixedPrecision(
model_->MainGraph(), weights_to_train, use_fp16_initializer, fp32_weight_name_to_fp16_node_arg));
std::unordered_set<std::string> fp16_weight_initializer_names{};
std::transform(
@ -797,7 +805,8 @@ Status TrainingSession::Save(const PathString& model_uri, TrainingSession::SaveO
auto status = Model::Save(*new_model, model_uri);
if (!status.IsOK()) {
LOGS(*session_logger_, WARNING) << "Error when saving model " << ToMBString(model_uri) << " : " << status.ErrorMessage();
LOGS(*session_logger_, WARNING)
<< "Error when saving model " << ToMBString(model_uri) << " : " << status.ErrorMessage();
}
return status;
@ -817,8 +826,8 @@ bool TrainingSession::IsGraphOutputFp32Node(const std::string& output_name) cons
ORT_ENFORCE(output_producer_node != nullptr, "Output: " + output_name + " is not produced by any node.");
for (auto output : output_producer_node->OutputDefs()) {
if (output->Name() == output_name && output->TypeAsProto() != nullptr && output->TypeAsProto()->has_tensor_type()
&& output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
if (output->Name() == output_name && output->TypeAsProto() != nullptr && output->TypeAsProto()->has_tensor_type() &&
output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
return true;
}
}
@ -835,24 +844,22 @@ common::Status TrainingSession::Run(const RunOptions& run_options, IOBinding& io
for (auto& drop_ratio : dropout_eval_feeds_) {
OrtValue feed_value;
// We allocate on CPU first, copy will be taken care of downstream.
auto cpu_allocator = GetSessionState().GetExecutionProviders()
.Get(onnxruntime::kCpuExecutionProvider)
->GetAllocator(0, OrtMemTypeDefault);
const auto* cpu_ep = GetSessionState().GetExecutionProviders().Get(onnxruntime::kCpuExecutionProvider);
const auto cpu_allocator = cpu_ep->GetAllocator(0, OrtMemTypeDefault);
feed_value = onnxruntime::MakeScalarMLValue<float>(cpu_allocator, 0.f, true /*is_1d*/);
// Bind new feed to graph input.
new_feeds.emplace_back(drop_ratio, feed_value);
}
}
else {
} else {
auto& input_names = io_binding.GetInputNames();
if (GetSessionState().GetInputNodeInfoMap().find(training_mode_string_) != GetSessionState().GetInputNodeInfoMap().end() &&
if (GetSessionState().GetInputNodeInfoMap().find(training_mode_string_) !=
GetSessionState().GetInputNodeInfoMap().end() &&
std::find(input_names.begin(), input_names.end(), training_mode_string_) == input_names.end()) {
// Set training_mode input to false
OrtValue training_mode_feed_value;
// We allocate on CPU first, copy will be taken care of downstream.
auto cpu_allocator = GetSessionState().GetExecutionProviders()
.Get(onnxruntime::kCpuExecutionProvider)
->GetAllocator(0, OrtMemTypeDefault);
const auto* cpu_ep = GetSessionState().GetExecutionProviders().Get(onnxruntime::kCpuExecutionProvider);
const auto cpu_allocator = cpu_ep->GetAllocator(0, OrtMemTypeDefault);
training_mode_feed_value = onnxruntime::MakeScalarMLValue<bool>(cpu_allocator, false, true /*is_1d*/);
new_feeds.emplace_back(training_mode_string_, training_mode_feed_value);
}
@ -879,18 +886,18 @@ Status TrainingSession::SetEvalFeedNames() {
for (auto& node : graph.Nodes()) {
auto it = Nodes_Need_Eval_Feeds.find(node.OpType());
if(it != Nodes_Need_Eval_Feeds.cend()) {
if (it != Nodes_Need_Eval_Feeds.cend()) {
// The opset is < 12, add each ratio input to graph inputs for overriding.
// Needs to be removed when TrainableDropout is deprecated.
if(it->compare("TrainableDropout") == 0) {
if (it->compare("TrainableDropout") == 0) {
auto& ratio_name = node.InputDefs()[1]->Name();
dropout_eval_feeds_.insert(ratio_name);
ORT_ENFORCE(model_->MainGraph().GetProducerNode(ratio_name) == nullptr,
"Input: " + ratio_name + " should not have any producer node.");
"Input: " + ratio_name + " should not have any producer node.");
defs.AddGraphInputs({ratio_name});
}
// Found an opset-12 dropout node, replace initializer name.
else if(node.InputArgCount().size() > 2) {
else if (node.InputArgCount().size() > 2) {
auto& mode_input = node.MutableInputDefs()[2];
const ONNX_NAMESPACE::TensorProto* mode_initializer = nullptr;
if (!graph.GetInitializedTensor(training_mode_string_, mode_initializer)) {
@ -910,7 +917,7 @@ Status TrainingSession::SetEvalFeedNames() {
}
}
}
ORT_RETURN_IF_ERROR(GraphAugmenter::AugmentGraph(graph, defs));
return DoPostLoadProcessing(*model_);
}
@ -1052,7 +1059,8 @@ std::unordered_set<std::string> TrainingSession::GetTrainableModelInitializers(
bool proceed = std::any_of(from->InputEdgesBegin(), from->InputEdgesEnd(), is_trainable_from_to_link);
if (!proceed && session_logger_) {
VLOGS(*session_logger_, 1) << "Stopping training parameters discovery traversal from " << from->Name() << " to " << to->Name() << std::endl;
VLOGS(*session_logger_, 1)
<< "Stopping training parameters discovery traversal from " << from->Name() << " to " << to->Name();
}
return !proceed;

View file

@ -22,7 +22,7 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) {
std::string l2_transformer = "ConvActivationFusion";
std::vector<std::string> custom_list = {l1_rule1, l1_transformer, l2_transformer};
auto transformers = training::transformer_utils::GenerateTransformers(TransformerLevel::Level1, {}, custom_list);
auto transformers = training::transformer_utils::GenerateTransformers(TransformerLevel::Level1, {}, {}, custom_list);
ASSERT_TRUE(transformers.size() == 1);
auto l1_rule_transformer_name = optimizer_utils::GenerateRuleBasedTransformerName(TransformerLevel::Level1);
@ -34,7 +34,7 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) {
}
ASSERT_TRUE(rule_transformer && rule_transformer->RulesCount() == 1);
transformers = training::transformer_utils::GenerateTransformers(TransformerLevel::Level2, {}, custom_list);
transformers = training::transformer_utils::GenerateTransformers(TransformerLevel::Level2, {}, {}, custom_list);
#ifndef DISABLE_CONTRIB_OPS
ASSERT_TRUE(transformers.size() == 1);
#else