diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 63f011e449..7ff3f75e3d 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1748,7 +1748,11 @@ TEST_F(GraphTransformationTests, SkipLayerNormFusion_Input_Output_Check) { // check outputs std::vector& output_defs = node.MutableOutputDefs(); +#ifdef ENABLE_TRAINING + EXPECT_EQ(node.OutputDefs().size(), 3u) << "SkipLayerNormalization number of outputs does not equal to 3. Got:" << node.OutputDefs().size(); +#else EXPECT_EQ(node.OutputDefs().size(), 1u) << "SkipLayerNormalization number of outputs does not equal to 1. Got:" << node.OutputDefs().size(); +#endif EXPECT_EQ(output_defs[0]->Name(), "19"); } else { EXPECT_EQ(node.OpType(), "MatMul") << "Unexpected node: " << node.OpType() << "," << node.Name();