This commit is contained in:
Vincent Wang 2020-11-12 07:41:50 +00:00 committed by Thiago Crepaldi
parent f6a8d2aa5f
commit 60b6e2683f

View file

@ -282,6 +282,7 @@ Status ModuleGradientGraphBuilder::Split() {
RemoveNodes(backward_graph, backward_nodes_to_remove);
// User inputs to backward graph inputs.
std::vector<const NodeArg*> backward_input_args;
for (const auto& input_name : split_graphs_info_.user_input_names) {
// Only takes those in the backward inputs.
@ -291,6 +292,11 @@ Status ModuleGradientGraphBuilder::Split() {
}
}
// Grad of user outputs to backward graph inputs.
for (const auto& output_grad_name : split_graphs_info_.backward_output_grad_names) {
backward_input_args.emplace_back(backward_graph.GetNodeArg(output_grad_name));
}
// Add initializer args to backward graph inputs if any node uses them.
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
// Some initializers will be inputs for backward graph.