mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
bugfix
This commit is contained in:
parent
f6a8d2aa5f
commit
60b6e2683f
1 changed files with 6 additions and 0 deletions
|
|
@ -282,6 +282,7 @@ Status ModuleGradientGraphBuilder::Split() {
|
|||
|
||||
RemoveNodes(backward_graph, backward_nodes_to_remove);
|
||||
|
||||
// User inputs to backward graph inputs.
|
||||
std::vector<const NodeArg*> backward_input_args;
|
||||
for (const auto& input_name : split_graphs_info_.user_input_names) {
|
||||
// Only takes those in the backward inputs.
|
||||
|
|
@ -291,6 +292,11 @@ Status ModuleGradientGraphBuilder::Split() {
|
|||
}
|
||||
}
|
||||
|
||||
// Grad of user outputs to backward graph inputs.
|
||||
for (const auto& output_grad_name : split_graphs_info_.backward_output_grad_names) {
|
||||
backward_input_args.emplace_back(backward_graph.GetNodeArg(output_grad_name));
|
||||
}
|
||||
|
||||
// Add initializer args to backward graph inputs if any node uses them.
|
||||
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
|
||||
// Some initializers will be inputs for backward graph.
|
||||
|
|
|
|||
Loading…
Reference in a new issue