diff --git a/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc b/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc index f6123cedc0..09c120e886 100644 --- a/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc +++ b/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc @@ -282,6 +282,7 @@ Status ModuleGradientGraphBuilder::Split() { RemoveNodes(backward_graph, backward_nodes_to_remove); + // User inputs to backward graph inputs. std::vector backward_input_args; for (const auto& input_name : split_graphs_info_.user_input_names) { // Only takes those in the backward inputs. @@ -291,6 +292,11 @@ Status ModuleGradientGraphBuilder::Split() { } } + // Grad of user outputs to backward graph inputs. + for (const auto& output_grad_name : split_graphs_info_.backward_output_grad_names) { + backward_input_args.emplace_back(backward_graph.GetNodeArg(output_grad_name)); + } + // Add initializer args to backward graph inputs if any node uses them. for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) { // Some initializers will be inputs for backward graph.