mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
MatMulScaleFusion - handling scale input (#11121)
* scale input * more condition check * alternative * per comments * fix comments Co-authored-by: Ethan Tao <ettao@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
parent
94032357e2
commit
bc296c706e
4 changed files with 55 additions and 1 deletions
|
|
@ -131,6 +131,11 @@ std::vector<ScaleMergeInfo> GetInputNodeMerges(
|
|||
ORT_ENFORCE(input_node.InputDefs().size() == 2 && scale_and_index->second < 2);
|
||||
const int to_scale_index = 1 - scale_and_index->second;
|
||||
|
||||
// check if the non-scale input is scalar
|
||||
const auto non_scale_node_arg = input_node.InputDefs()[to_scale_index];
|
||||
const auto* shape = non_scale_node_arg->Shape();
|
||||
if (shape == nullptr || shape->dim_size() == 0) continue;
|
||||
|
||||
input_node_merges.push_back(
|
||||
{input_edge,
|
||||
scale_and_index->first,
|
||||
|
|
|
|||
|
|
@ -4727,6 +4727,18 @@ TEST_F(GraphTransformationTests, MatMulScaleFusionUnsupportedInputType) {
|
|||
{kCpuExecutionProvider});
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, MatMulScaleFusionWithScaleInput) {
|
||||
TestMatMulScaleFusion(
|
||||
MODEL_FOLDER "fusion/matmul_scale_with_scale_input.onnx", *logger_,
|
||||
[](const Graph&,
|
||||
const std::map<std::string, int>&,
|
||||
std::map<std::string, int> transformed_op_counts) {
|
||||
EXPECT_EQ(transformed_op_counts["Mul"], 1);
|
||||
EXPECT_EQ(transformed_op_counts["MatMul"], 1);
|
||||
EXPECT_EQ(transformed_op_counts["com.microsoft.FusedMatMul"], 0);
|
||||
});
|
||||
}
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
TEST_F(GraphTransformationTests, IsInfReduceSum_Test) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/isinf_reducesum.onnx";
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ opsets = [onnxdomain, msdomain]
|
|||
scale_value = 3.0
|
||||
|
||||
|
||||
def save(model_path, nodes, inputs, outputs, initializers):
|
||||
def save(model_path, nodes, inputs, outputs, initializers, opsets=opsets):
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"MatMulScaleTest",
|
||||
|
|
@ -218,3 +218,40 @@ def gen_int32(model_path):
|
|||
|
||||
|
||||
gen_int32("matmul_scale_int32.onnx")
|
||||
|
||||
|
||||
def gen_scale_input(model_path):
|
||||
|
||||
nodes = [
|
||||
helper.make_node(
|
||||
"Mul", ["input_0", "scale"], ["scaled_input_0"],
|
||||
"scale input_0"),
|
||||
helper.make_node(
|
||||
"MatMul", ["scaled_input_0", "input_1"], ["output_0"],
|
||||
"MatMul input_0 and input_1"),
|
||||
]
|
||||
|
||||
initializers = [
|
||||
helper.make_tensor("scale", TensorProto.FLOAT, [1], [1.0])
|
||||
]
|
||||
|
||||
inputs = [
|
||||
helper.make_tensor_value_info(
|
||||
"input_0", TensorProto.FLOAT, []),
|
||||
helper.make_tensor_value_info(
|
||||
"input_1", TensorProto.FLOAT, [1, 'K']),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
helper.make_tensor_value_info(
|
||||
"output_0", TensorProto.FLOAT, ['K']),
|
||||
]
|
||||
|
||||
onnxdomain = OperatorSetIdProto()
|
||||
onnxdomain.version = 14
|
||||
onnxdomain.domain = ""
|
||||
opsets = [onnxdomain, msdomain]
|
||||
save(model_path, nodes, inputs, outputs, initializers, opsets)
|
||||
|
||||
|
||||
gen_scale_input("matmul_scale_with_scale_input.onnx")
|
||||
BIN
onnxruntime/test/testdata/transform/fusion/matmul_scale_with_scale_input.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/matmul_scale_with_scale_input.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue