MatMulScaleFusion - handling scale input (#11121)

* scale input

* more condition check

* alternative

* per comments

* fix comments

Co-authored-by: Ethan Tao <ettao@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
ytaous 2022-04-14 21:54:04 -07:00 committed by GitHub
parent 94032357e2
commit bc296c706e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 55 additions and 1 deletions

View file

@ -131,6 +131,11 @@ std::vector<ScaleMergeInfo> GetInputNodeMerges(
ORT_ENFORCE(input_node.InputDefs().size() == 2 && scale_and_index->second < 2);
const int to_scale_index = 1 - scale_and_index->second;
// check if the non-scale input is scalar
const auto non_scale_node_arg = input_node.InputDefs()[to_scale_index];
const auto* shape = non_scale_node_arg->Shape();
if (shape == nullptr || shape->dim_size() == 0) continue;
input_node_merges.push_back(
{input_edge,
scale_and_index->first,

View file

@ -4727,6 +4727,18 @@ TEST_F(GraphTransformationTests, MatMulScaleFusionUnsupportedInputType) {
{kCpuExecutionProvider});
}
TEST_F(GraphTransformationTests, MatMulScaleFusionWithScaleInput) {
TestMatMulScaleFusion(
MODEL_FOLDER "fusion/matmul_scale_with_scale_input.onnx", *logger_,
[](const Graph&,
const std::map<std::string, int>&,
std::map<std::string, int> transformed_op_counts) {
EXPECT_EQ(transformed_op_counts["Mul"], 1);
EXPECT_EQ(transformed_op_counts["MatMul"], 1);
EXPECT_EQ(transformed_op_counts["com.microsoft.FusedMatMul"], 0);
});
}
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST_F(GraphTransformationTests, IsInfReduceSum_Test) {
auto model_uri = MODEL_FOLDER "fusion/isinf_reducesum.onnx";

View file

@ -16,7 +16,7 @@ opsets = [onnxdomain, msdomain]
scale_value = 3.0
def save(model_path, nodes, inputs, outputs, initializers):
def save(model_path, nodes, inputs, outputs, initializers, opsets=opsets):
graph = helper.make_graph(
nodes,
"MatMulScaleTest",
@ -218,3 +218,40 @@ def gen_int32(model_path):
gen_int32("matmul_scale_int32.onnx")
def gen_scale_input(model_path):
nodes = [
helper.make_node(
"Mul", ["input_0", "scale"], ["scaled_input_0"],
"scale input_0"),
helper.make_node(
"MatMul", ["scaled_input_0", "input_1"], ["output_0"],
"MatMul input_0 and input_1"),
]
initializers = [
helper.make_tensor("scale", TensorProto.FLOAT, [1], [1.0])
]
inputs = [
helper.make_tensor_value_info(
"input_0", TensorProto.FLOAT, []),
helper.make_tensor_value_info(
"input_1", TensorProto.FLOAT, [1, 'K']),
]
outputs = [
helper.make_tensor_value_info(
"output_0", TensorProto.FLOAT, ['K']),
]
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 14
onnxdomain.domain = ""
opsets = [onnxdomain, msdomain]
save(model_path, nodes, inputs, outputs, initializers, opsets)
gen_scale_input("matmul_scale_with_scale_input.onnx")